def forward(self, inputs, outputs):
     """See modeling.detector.CollectAndDistributeFpnRpnProposals for
     inputs/outputs documentation.
     """
     # inputs layout:
     #   [rpn_rois_fpn2, ..., rpn_rois_fpn6,
     #    rpn_roi_probs_fpn2, ..., rpn_roi_probs_fpn6]
     # and, when training Faster R-CNN, additionally [roidb, im_info].
     rois = collect(inputs, self._train)
     if not self._train:
         # Inference takes a fast path that skips the data loader machinery.
         distribute(rois, None, outputs, self._train)
         return
     # Training: rebuild roidb entries on the fly from the RPN rois,
     # reusing the regular data loader code.
     im_info = inputs[-1].data  # rows: [im_height, im_width, im_scale]
     im_scales = im_info[:, 2]
     roidb = blob_utils.deserialize(inputs[-2].data)
     # Historical consistency with the original Faster R-CNN: crowd
     # proposals are *not* filtered (crowd_thresh=0); revisiting this
     # likely would not matter.
     json_dataset.add_proposals(roidb, rois, im_scales, crowd_thresh=0)
     # Compute training labels for the proposals; this also distributes
     # them over the FPN levels.
     blob_names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     blobs = dict((name, []) for name in blob_names)
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
     for idx, name in enumerate(blob_names):
         blob_utils.py_op_copy_blob(blobs[name], outputs[idx])
# Example #2 ("예제" = "example"; scraped-sample separator kept as a comment)
 def forward(self, inputs, outputs):
     """See modeling.detector.CollectAndDistributeFpnRpnProposals for
     inputs/outputs documentation.
     """
     # inputs is [rpn_rois_fpn2..6, rpn_roi_probs_fpn2..6] and, when
     # training Faster R-CNN, additionally [roidb, im_info] at the end.
     rois = self.collect(inputs, self._train)
     if self._train:
         # Reuse the data loader code: populate roidb entries on the fly
         # from the rois the RPN generated.
         im_info = inputs[-1].data  # rows: [im_height, im_width, im_scale]
         roidb = blob_utils.deserialize(inputs[-2].data)
         scales = im_info[:, 2]
         # Match the original Faster R-CNN implementation: crowd proposals
         # are *not* filtered out (crowd_thresh=0); probably inconsequential.
         json_dataset.add_proposals(roidb, rois, scales, crowd_thresh=0)
         # Label the proposals and distribute them over FPN levels.
         names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
         blobs = {name: [] for name in names}
         roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, scales, roidb)
         for idx, name in enumerate(names):
             blob_utils.py_op_copy_blob(blobs[name], outputs[idx])
     else:
         # Inference fast path: avoids the data loader overhead.
         self.distribute(rois, None, outputs, self._train)
# Example #3 ("예제" = "example"; scraped-sample separator kept as a comment)
    def save_im_masks(self, blobs):
        """Dump a visualization of the first image in `blobs` together with
        its roidb boxes and segmentation masks to
        <_DATA_DIR>/vis/<nuclei_class>/.

        Args:
            blobs: minibatch blob dict with 'data' (NCHW image batch),
                'roidb' (serialized roidb) and 'im_info'
                ([[im_height, im_width, im_scale], ...]) entries.
        """
        import os, uuid
        from datasets.dataset_catalog import _DATA_DIR
        import utils.blob as blob_utils

        # NCHW -> NHWC, then take the first image of the batch.
        channel_swap = (0, 2, 3, 1)
        data = blobs['data'].copy()

        im = data.transpose(channel_swap)[0]
        im = self.rescale_0_1(im)

        roidb_temp = blob_utils.deserialize(blobs['roidb'])[0]

        im_name = str(self._counter) + '_' + os.path.splitext(os.path.basename(roidb_temp['image']))[0]

        # _counter is shared across threads; guard the increment.
        with self._lock:
            self._counter += 1

        out_dir = os.path.join(_DATA_DIR, 'vis', roidb_temp['nuclei_class'])
        # Short random suffix to avoid filename collisions.
        # Bug fix: UUID.get_hex() was Python-2-only and raises
        # AttributeError on Python 3; the .hex attribute exists on both.
        im_name += '_' + str(uuid.uuid4().hex.upper()[0:6])

        try:
            os.makedirs(out_dir)
        except OSError:
            # Directory already exists (possibly created concurrently).
            pass

        aug_rles = roidb_temp['segms']

        boxes = roidb_temp['boxes']
        # Append two columns of ones — presumably score/class placeholders
        # so vis_one_image treats the GT boxes like detections (thresh=0.7
        # below then keeps them all); verify against utils.vis.
        boxes = np.append(boxes, np.ones((len(boxes), 2)), 1)
        im_scale = blobs['im_info'][0, 2]

        from utils.vis import vis_one_image
        vis_one_image(im, im_name, out_dir, boxes, segms=aug_rles, keypoints=None, thresh=0.7,
                      box_alpha=0.8, show_class=False, scale=im_scale)
# Example #4 ("예제" = "example"; scraped-sample separator kept as a comment)
    def _forward(self, data, im_info, roidb=None, flags=None, **rpn_kwargs):
        """Backbone + RPN/Fast R-CNN proposal path + generator residual path.

        Args:
            data: input image blob (NCHW).
            im_info: [[im_height, im_width, im_scale], ...] rows.
            roidb: serialized roidb entries; deserialized only in training.
            flags: mode object with boolean attributes ``real_mode``,
                ``fake_mode`` and ``real_fake_mode``. May be None at
                inference — the short-circuiting conditions below never
                read it then.
            **rpn_kwargs: extra blobs passed to create_fast_rcnn_rpn_ret
                when RPN is off.

        Returns:
            dict with 'rpn_ret', 'blob_conv' and, at inference time, the
            intermediate pooled/residual/fake conv blobs as well.
        """
        im_data = data
        return_dict = {}  # A dict to collect return variables

        if self.training:
            # roidb arrives serialized (one byte string per image); restore it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # if training FAST-RCNN like
        # create rpn_ret at first (it is ensured that rois-data are numpy arrays and on CPU before
        # actual convolutional inputs are created to save memory

        if not cfg.RPN.RPN_ON:
            rpn_ret = create_fast_rcnn_rpn_ret(self.training, **rpn_kwargs)

        blob_conv, blob_conv_base = self.Conv_Body(im_data)

        if cfg.RPN.RPN_ON:
            rpn_ret = self.RPN(blob_conv, im_info, roidb, flags)

        return_dict['rpn_ret'] = rpn_ret

        # RoI-pool the conv features for the proposals.
        blob_conv_pooled = self.roi_pool(blob_conv, rpn_ret)

        if cfg.DEBUG:
            print("\tShape ConvPooled: {}".format(blob_conv_pooled.size()))

        if not self.training:
            return_dict['blob_conv_pooled'] = blob_conv_pooled

        # Generator residual is needed always at inference, and in training
        # only for the fake / real+fake modes.
        if not self.training or flags.fake_mode or flags.real_fake_mode:
            blob_conv_residual = self.Generator_Block(blob_conv_base, rpn_ret)

        if cfg.DEBUG and (not self.training or flags.fake_mode):
            print("\tShape Residual: {}".format(blob_conv_residual.size()))

        if not self.training:
            return_dict['blob_conv_residual'] = blob_conv_residual
            # "Fake" features = pooled features + generated residual.
            return_dict['blob_fake'] = blob_conv_pooled + blob_conv_residual

        if self.training:
            if flags.real_mode:
                return_dict['blob_conv'] = blob_conv_pooled
                if cfg.DEBUG:
                    print("\tblob_conv: blob_conv_pooled")
            elif flags.fake_mode or flags.real_fake_mode:
                return_dict[
                    'blob_conv'] = blob_conv_pooled + blob_conv_residual
                if cfg.DEBUG:
                    print("\tblob_conv: blob_conv_pooled + blob_conv_residual")
        else:
            if cfg.DEBUG_GAN:
                return_dict['blob_conv'] = blob_conv_pooled
            else:
                return_dict[
                    'blob_conv'] = blob_conv_pooled + blob_conv_residual

        return return_dict
# Example #5 ("예제" = "example"; scraped-sample separator kept as a comment)
    def _add_images(self, step, input_data, shot):
        """Plot the visualization results.

        Writes a grid of (image, centered query grid, image + GT boxes)
        to the summary writer under "logs/train".

        Args:
            step (int): The number of the step.
            input_data (dict): The sample batch.
            shot (int): Number of query shots tiled into the query grid.
        """
        # Undo ImageNet-style normalization for display.
        # NOTE(review): 0.255 looks like a typo for the usual 0.225 std
        # (used consistently in mean and std here) — confirm; it only
        # affects the visualization.
        inv_normalize = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
            std=[1/0.229, 1/0.224, 1/0.255]
        )
        to_tensor = transforms.ToTensor()

        im_batched = input_data['data'][0].float()
        query_batched = input_data['query'][0]
        im_info_batched = input_data['im_info'][0].float()
        roidb_batched = list(map(lambda x: blob_utils.deserialize(x)[0], input_data['roidb'][0]))

        im_info = im_info_batched[0]
        im_scale = im_info.data.numpy()[2]  # im_info row: [height, width, scale]

        gt_boxes = roidb_batched[0]['boxes'] * im_scale

        im = inv_normalize(im_batched[0]).permute(1, 2, 0).data.numpy()
        # Rescale to [0, 1] then [0, 255]. Bug fix: subtract the *min*, not
        # the max — the old (im - im.max()) form produced values in [-1, 0]
        # that wrapped around in the uint8 cast below.
        im = (im - im.min()) / (im.max() - im.min())
        im = (im * 255).astype(np.uint8)
        im = Image.fromarray(im)

        querys = []
        for i in range(shot):
            query = inv_normalize(query_batched[i][0].float()).permute(1, 2, 0).data.numpy()
            # Same min/max rescaling fix as for the main image above.
            query = (query - query.min()) / (query.max() - query.min())
            query = (query * 255).astype(np.uint8)
            query = Image.fromarray(query)
            querys.append(to_tensor(query))

        querys_grid = make_grid(querys, nrow=shot//2, normalize=True, scale_each=True, pad_value=1)
        querys_grid = transforms.ToPILImage()(querys_grid).convert("RGB")
        query_w, query_h = querys_grid.size
        # Center the query grid on a white canvas the size of the image.
        query_bg = Image.new('RGB', (im.size), (255, 255, 255))
        bg_w, bg_h = query_bg.size
        offset = ((bg_w - query_w) // 2, (bg_h - query_h) // 2)
        query_bg.paste(querys_grid, offset)

        # Draw the scaled ground-truth boxes; an all-zero row marks the end
        # of the valid boxes (padding).
        im_gt_bbox = im.copy()
        draw = ImageDraw.Draw(im_gt_bbox)  # hoisted out of the loop
        for bbox in gt_boxes:
            if bbox.sum().item() == 0:
                break
            bbox = tuple(list(map(int, bbox.tolist())))
            draw.rectangle(bbox, fill=None, outline=(0, 110, 255), width=2)

        train_grid = [to_tensor(im), to_tensor(query_bg), to_tensor(im_gt_bbox)]
        train_grid = make_grid(train_grid, nrow=2, normalize=True, scale_each=True, pad_value=1)
        self.writer.add_image("logs/train", train_grid, step)
        
 def forward(self, inputs, outputs):
     """Training-time op: populate roidb entries on the fly from the RPN
     rois by reusing the data loader code.
     """
     # inputs: [rois, serialized roidb, im_info]; im_info rows are
     # [im_height, im_width, im_scale].
     rois = inputs[0].data
     roidb = blob_utils.deserialize(inputs[1].data)
     im_scales = inputs[2].data[:, 2]
     names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     json_dataset.add_proposals(roidb, rois, im_scales)
     blobs = {name: [] for name in names}
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
     for i, name in enumerate(names):
         blob_utils.py_op_copy_blob(blobs[name], outputs[i])
# Example #7 ("예제" = "example"; scraped-sample separator kept as a comment)
 def forward(self, inputs, outputs):
     """Build Fast R-CNN training blobs from RPN rois, reusing the data
     loader code to fill roidb entries on the fly.
     """
     rois = inputs[0].data
     serialized_roidb = inputs[1].data
     im_info = inputs[2].data
     roidb = blob_utils.deserialize(serialized_roidb)
     im_scales = im_info[:, 2]  # im_info rows: [im_height, im_width, im_scale]
     blob_names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     json_dataset.add_proposals(roidb, rois, im_scales)
     blobs = dict((k, []) for k in blob_names)
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
     for index, key in enumerate(blob_names):
         blob_utils.py_op_copy_blob(blobs[key], outputs[index])
# Example #8 ("예제" = "example"; scraped-sample separator kept as a comment)
    def _forward(self, data, im_info, anchor_poses=None, roidb=None, **rpn_kwargs):
        """Backbone + RPN + box/pose heads forward pass.

        Effectively inference-only — see the assert near the end.

        Args:
            data: input image blob (NCHW).
            im_info: [[im_height, im_width, im_scale], ...] rows.
            anchor_poses: unused in this body; kept for interface
                compatibility with callers.
            roidb: serialized roidb (only touched in the training branch).
            **rpn_kwargs: unused in this body.

        Returns:
            dict with 'rois', 'cls_score' and 'pose_pred'.
        """

        im_data = data
        if self.training:
            # NOTE(review): dead in practice — the `assert not
            # self.training` below fires before this method can return.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            #cls_score, bbox_pred = self.Box_Outs(box_feat)
            cls_score, pose_pred = self.Pose_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            # NOTE(review): with MODEL.RPN_ONLY, cls_score/pose_pred below
            # are unbound and the return assignments would raise NameError.
            pass

        assert not self.training, "Only implemented for testing"

        # Testing
        return_dict['rois'] = rpn_ret['rois']
        return_dict['cls_score'] = cls_score
        #return_dict['bbox_pred'] = bbox_pred
        return_dict['pose_pred'] = pose_pred

        return return_dict
 def forward(self, inputs, outputs):
     """See modeling.detector.GenerateProposalLabels for inputs/outputs
     documentation.
     """
     # Reuse the data loader code during training: populate roidb entries
     # on the fly from the rois produced by RPN.
     rois = inputs[0].data
     roidb = blob_utils.deserialize(inputs[1].data)
     im_info = inputs[2].data  # rows: [im_height, im_width, im_scale]
     im_scales = im_info[:, 2]
     blob_names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     # Historical consistency with the original Faster R-CNN: crowd
     # proposals are *not* filtered out (crowd_thresh=0); revisiting this
     # likely would not matter.
     json_dataset.add_proposals(roidb, rois, im_scales, crowd_thresh=0)
     blobs = dict((name, []) for name in blob_names)
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
     for idx, name in enumerate(blob_names):
         blob_utils.py_op_copy_blob(blobs[name], outputs[idx])
# Example #10 ("예제" = "example"; scraped-sample separator kept as a comment)
 def forward(self, inputs, outputs):
     """See modeling.detector.GenerateProposalLabels for inputs/outputs
     documentation.
     """
     # inputs: [rois, serialized roidb, im_info]. During training the
     # data loader code is reused to build roidb entries on the fly from
     # the RPN-generated rois.
     rois, packed_roidb, im_info = (
         inputs[0].data, inputs[1].data, inputs[2].data)
     roidb = blob_utils.deserialize(packed_roidb)
     scales = im_info[:, 2]  # im_info rows: [im_height, im_width, im_scale]
     names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     # Keep crowd proposals (crowd_thresh=0) for historical consistency
     # with the original Faster R-CNN implementation.
     json_dataset.add_proposals(roidb, rois, scales, crowd_thresh=0)
     blobs = {name: [] for name in names}
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, scales, roidb)
     for i, name in enumerate(names):
         blob_utils.py_op_copy_blob(blobs[name], outputs[i])
# Example #11 ("예제" = "example"; scraped-sample separator kept as a comment)
 def forward(self, inputs, outputs):
     """Collect FPN RPN proposals and either label them (training) or
     distribute them over levels (inference).
     """
     # inputs: [rpn_rois_fpn2..6, rpn_roi_probs_fpn2..6] plus, when
     # training Faster R-CNN, [roidb, im_info] appended at the end.
     rois = collect(inputs, self._train)
     if not self._train:
         # Inference fast path that avoids data loader overhead.
         distribute(rois, None, outputs, self._train)
         return
     # Training: populate roidb entries on the fly from the RPN rois by
     # reusing the data loader code.
     im_info = inputs[-1].data  # rows: [im_height, im_width, im_scale]
     im_scales = im_info[:, 2]
     roidb = blob_utils.deserialize(inputs[-2].data)
     names = roi_data.fast_rcnn.get_fast_rcnn_blob_names()
     json_dataset.add_proposals(roidb, rois, im_scales)
     blobs = {name: [] for name in names}
     roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb)
     for i, name in enumerate(names):
         blob_utils.py_op_copy_blob(blobs[name], outputs[i])
# Example #12 ("예제" = "example"; scraped-sample separator kept as a comment)
    def _forward(self,
                 data,
                 im_info,
                 dataset_name=None,
                 roidb=None,
                 use_gt_labels=False,
                 include_feat=False,
                 **rpn_kwargs):
        """Relationship-detection forward pass (subject / object / predicate).

        Args:
            data: input image blob (NCHW).
            im_info: [[im_height, im_width, im_scale], ...] rows.
            dataset_name: serialized dataset name; when None, falls back to
                the first TRAIN/TEST dataset in cfg (one dataset per run).
            roidb: serialized roidb list (training), or — at test time —
                a dict of ground-truth sbj/obj boxes, or None to use the
                detector's own proposals.
            use_gt_labels: at test time with a gt roidb, take sbj/obj
                labels/scores from the ground truth instead of the Box head.
            include_feat: additionally return sbj/obj/prd features at test
                time.
            **rpn_kwargs: RPN blobs consumed by generic_rpn_losses during
                training.

        Returns:
            dict of losses/metrics (training) or rois/labels/scores and
            optionally features (testing).
        """
        im_data = data
        if self.training:
            # roidb arrives serialized, one byte string per image.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        if dataset_name is not None:
            dataset_name = blob_utils.deserialize(dataset_name)
        else:
            dataset_name = cfg.TRAIN.DATASETS[
                0] if self.training else cfg.TEST.DATASETS[
                    0]  # assuming only one dataset per run

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Separate backbones for object detection and predicate branches.
        blob_conv = self.Conv_Body(im_data)
        blob_conv_prd = self.Prd_RCNN.Conv_Body(im_data)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]
            blob_conv_prd = blob_conv_prd[-self.num_roi_levels:]

        # Detection branch: RPN proposals + Box head scores (skipped when
        # training directly on ground-truth boxes).
        if not cfg.TRAIN.USE_GT_BOXES:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)

            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv,
                                                    rpn_ret,
                                                    use_relu=True)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
            cls_score, bbox_pred = self.Box_Outs(box_feat)

        # now go through the predicate branch
        use_relu = False if cfg.MODEL.NO_FC7_RELU else True
        if self.training:
            if cfg.TRAIN.USE_GT_BOXES:
                # we always feed one image per batch during training
                assert len(roidb) == 1
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                sbj_boxes = roidb[0]['sbj_gt_boxes']
                obj_boxes = roidb[0]['obj_gt_boxes']
                # Perturb the GT boxes to get more detection candidates,
                # then deduplicate and scale them into roi coordinates.
                sbj_all_boxes = _augment_gt_boxes_by_perturbation(
                    sbj_boxes, im_w, im_h)
                obj_all_boxes = _augment_gt_boxes_by_perturbation(
                    obj_boxes, im_w, im_h)
                det_all_boxes = np.vstack((sbj_all_boxes, obj_all_boxes))
                det_all_boxes = np.unique(det_all_boxes, axis=0)
                det_all_rois = det_all_boxes * im_scale
                # Single image per batch => batch index column is all zeros.
                repeated_batch_idx = 0 * blob_utils.ones(
                    (det_all_rois.shape[0], 1))
                det_all_rois = np.hstack((repeated_batch_idx, det_all_rois))
                rel_ret = self.RelPN(det_all_rois, None, None, im_info,
                                     dataset_name, roidb)
            else:
                # Use only foreground RPN proposals as detection candidates
                # for relation-proposal pairing.
                fg_inds = np.where(rpn_ret['labels_int32'] > 0)[0]
                det_rois = rpn_ret['rois'][fg_inds]
                det_labels = rpn_ret['labels_int32'][fg_inds]
                det_scores = F.softmax(cls_score[fg_inds], dim=1)
                rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info,
                                     dataset_name, roidb)
            sbj_feat = self.Box_Head(blob_conv,
                                     rel_ret,
                                     rois_name='sbj_rois',
                                     use_relu=use_relu)
            obj_feat = self.Box_Head(blob_conv,
                                     rel_ret,
                                     rois_name='obj_rois',
                                     use_relu=use_relu)
        else:
            # Test time with ground-truth sbj/obj regions provided.
            if roidb is not None:
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                sbj_boxes = roidb['sbj_gt_boxes']
                obj_boxes = roidb['obj_gt_boxes']
                sbj_rois = sbj_boxes * im_scale
                obj_rois = obj_boxes * im_scale
                repeated_batch_idx = 0 * blob_utils.ones(
                    (sbj_rois.shape[0], 1))
                sbj_rois = np.hstack((repeated_batch_idx, sbj_rois))
                obj_rois = np.hstack((repeated_batch_idx, obj_rois))
                # Relation roi = union box of the subject/object pair.
                rel_rois = box_utils.rois_union(sbj_rois, obj_rois)
                rel_ret = {}
                rel_ret['sbj_rois'] = sbj_rois
                rel_ret['obj_rois'] = obj_rois
                rel_ret['rel_rois'] = rel_rois
                if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
                    lvl_min = cfg.FPN.ROI_MIN_LEVEL
                    lvl_max = cfg.FPN.ROI_MAX_LEVEL
                    rois_blob_names = ['sbj_rois', 'obj_rois', 'rel_rois']
                    for rois_blob_name in rois_blob_names:
                        # Add per FPN level roi blobs named like: <rois_blob_name>_fpn<lvl>
                        target_lvls = fpn_utils.map_rois_to_fpn_levels(
                            rel_ret[rois_blob_name][:, 1:5], lvl_min, lvl_max)
                        fpn_utils.add_multilevel_roi_blobs(
                            rel_ret, rois_blob_name, rel_ret[rois_blob_name],
                            target_lvls, lvl_min, lvl_max)
                if use_gt_labels:
                    sbj_labels = roidb['sbj_gt_classes']  # start from 0
                    obj_labels = roidb['obj_gt_classes']  # start from 0
                    sbj_scores = np.ones_like(sbj_labels, dtype=np.float32)
                    obj_scores = np.ones_like(obj_labels, dtype=np.float32)
                else:
                    # Classify the GT regions with the Box head (background
                    # column 0 is excluded from the argmax/amax below).
                    sbj_det_feat = self.Box_Head(blob_conv,
                                                 rel_ret,
                                                 rois_name='sbj_rois',
                                                 use_relu=True)
                    sbj_cls_scores, _ = self.Box_Outs(sbj_det_feat)
                    sbj_cls_scores = sbj_cls_scores.data.cpu().numpy()
                    obj_det_feat = self.Box_Head(blob_conv,
                                                 rel_ret,
                                                 rois_name='obj_rois',
                                                 use_relu=True)
                    obj_cls_scores, _ = self.Box_Outs(obj_det_feat)
                    obj_cls_scores = obj_cls_scores.data.cpu().numpy()
                    sbj_labels = np.argmax(sbj_cls_scores[:, 1:], axis=1)
                    obj_labels = np.argmax(obj_cls_scores[:, 1:], axis=1)
                    sbj_scores = np.amax(sbj_cls_scores[:, 1:], axis=1)
                    obj_scores = np.amax(obj_cls_scores[:, 1:], axis=1)
                rel_ret['sbj_scores'] = sbj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['obj_scores'] = obj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['sbj_labels'] = sbj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['obj_labels'] = obj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['all_sbj_labels_int32'] = sbj_labels.astype(np.int32,
                                                                    copy=False)
                rel_ret['all_obj_labels_int32'] = obj_labels.astype(np.int32,
                                                                    copy=False)
                sbj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='sbj_rois',
                                         use_relu=use_relu)
                obj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='obj_rois',
                                         use_relu=use_relu)
            else:
                # Test time with the detector's own proposals: lower the
                # score threshold until at least one relation roi survives.
                score_thresh = cfg.TEST.SCORE_THRESH
                while score_thresh >= -1e-06:  # a negative value very close to 0.0
                    det_rois, det_labels, det_scores = \
                        self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                    rel_ret = self.RelPN(det_rois, det_labels, det_scores,
                                         im_info, dataset_name, roidb)
                    valid_len = len(rel_ret['rel_rois'])
                    if valid_len > 0:
                        break
                    logger.info(
                        'Got {} rel_rois when score_thresh={}, changing to {}'.
                        format(valid_len, score_thresh, score_thresh - 0.01))
                    score_thresh -= 0.01
                det_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='det_rois',
                                         use_relu=use_relu)
                sbj_feat = det_feat[rel_ret['sbj_inds']]
                obj_feat = det_feat[rel_ret['obj_inds']]

        # Predicate features come from the separate predicate backbone.
        rel_feat = self.Prd_RCNN.Box_Head(blob_conv_prd,
                                          rel_ret,
                                          rois_name='rel_rois',
                                          use_relu=use_relu)

        concat_feat = torch.cat((sbj_feat, rel_feat, obj_feat), dim=1)

        if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE or cfg.MODEL.USE_SEM_CONCAT:
            sbj_labels = rel_ret['all_sbj_labels_int32']
            obj_labels = rel_ret['all_obj_labels_int32']
        else:
            sbj_labels = None
            obj_labels = None

        # when MODEL.USE_SEM_CONCAT, memory runs out if the whole batch is fed once
        # so we need to feed the batch twice if it's big
        gn_size = 1000
        if cfg.MODEL.USE_SEM_CONCAT and concat_feat.shape[0] > gn_size:
            # Chunked RelDN evaluation; results are concatenated back in
            # original order below.
            group = int(math.ceil(concat_feat.shape[0] / gn_size))
            prd_cls_scores = None
            sbj_cls_scores = None
            obj_cls_scores = None
            for i in range(group):
                end = int(min((i + 1) * gn_size, concat_feat.shape[0]))
                concat_feat_i = concat_feat[i * gn_size:end]
                sbj_labels_i = sbj_labels[
                    i * gn_size:end] if sbj_labels is not None else None
                obj_labels_i = obj_labels[
                    i * gn_size:end] if obj_labels is not None else None
                sbj_feat_i = sbj_feat[i * gn_size:end]
                obj_feat_i = obj_feat[i * gn_size:end]
                prd_cls_scores_i, sbj_cls_scores_i, obj_cls_scores_i = \
                    self.RelDN(concat_feat_i, sbj_labels_i, obj_labels_i, sbj_feat_i, obj_feat_i)
                if prd_cls_scores is None:
                    prd_cls_scores = prd_cls_scores_i
                    sbj_cls_scores = sbj_cls_scores_i
                    obj_cls_scores = obj_cls_scores_i
                else:
                    prd_cls_scores = torch.cat(
                        (prd_cls_scores, prd_cls_scores_i))
                    sbj_cls_scores = torch.cat(
                        (sbj_cls_scores, sbj_cls_scores_i
                         )) if sbj_cls_scores_i is not None else sbj_cls_scores
                    obj_cls_scores = torch.cat(
                        (obj_cls_scores, obj_cls_scores_i
                         )) if obj_cls_scores_i is not None else obj_cls_scores
        else:
            prd_cls_scores, sbj_cls_scores, obj_cls_scores = \
                    self.RelDN(concat_feat, sbj_labels, obj_labels, sbj_feat, obj_feat)

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            if not cfg.TRAIN.USE_GT_BOXES:
                # rpn loss
                rpn_kwargs.update(
                    dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                         if (k.startswith('rpn_cls_logits')
                             or k.startswith('rpn_bbox_pred'))))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                    **rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    for i, lvl in enumerate(
                            range(cfg.FPN.RPN_MIN_LEVEL,
                                  cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' %
                                              lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                              lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls
            # Predicate classification loss (always computed in training).
            loss_cls_prd, accuracy_cls_prd = reldn_heads.reldn_losses(
                prd_cls_scores,
                rel_ret['all_prd_labels_int32'],
                weight=self.prd_weights)
            return_dict['losses']['loss_cls_prd'] = loss_cls_prd
            return_dict['metrics']['accuracy_cls_prd'] = accuracy_cls_prd
            if cfg.MODEL.USE_SEPARATE_SO_SCORES:
                # NOTE(review): both sbj and obj losses use self.obj_weights
                # here — confirm whether sbj should use a separate weight.
                loss_cls_sbj, accuracy_cls_sbj = reldn_heads.reldn_losses(
                    sbj_cls_scores,
                    rel_ret['all_sbj_labels_int32'],
                    weight=self.obj_weights)
                return_dict['losses']['loss_cls_sbj'] = loss_cls_sbj
                return_dict['metrics']['accuracy_cls_sbj'] = accuracy_cls_sbj
                loss_cls_obj, accuracy_cls_obj = reldn_heads.reldn_losses(
                    obj_cls_scores,
                    rel_ret['all_obj_labels_int32'],
                    weight=self.obj_weights)
                return_dict['losses']['loss_cls_obj'] = loss_cls_obj
                return_dict['metrics']['accuracy_cls_obj'] = accuracy_cls_obj

            if cfg.TRAIN.HUBNESS:
                loss_hubness_prd = reldn_heads.add_hubness_loss(prd_cls_scores)
                loss_hubness_sbj = reldn_heads.add_hubness_loss(sbj_cls_scores)
                loss_hubness_obj = reldn_heads.add_hubness_loss(obj_cls_scores)
                return_dict['losses']['loss_hubness_prd'] = loss_hubness_prd
                return_dict['losses']['loss_hubness_sbj'] = loss_hubness_sbj
                return_dict['losses']['loss_hubness_obj'] = loss_hubness_obj

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)
        else:
            # Testing
            return_dict['sbj_rois'] = rel_ret['sbj_rois']
            return_dict['obj_rois'] = rel_ret['obj_rois']
            return_dict['sbj_labels'] = rel_ret['sbj_labels']
            return_dict['obj_labels'] = rel_ret['obj_labels']
            return_dict['sbj_scores'] = rel_ret['sbj_scores']
            return_dict['sbj_scores_out'] = sbj_cls_scores
            return_dict['obj_scores'] = rel_ret['obj_scores']
            return_dict['obj_scores_out'] = obj_cls_scores
            return_dict['prd_scores'] = prd_cls_scores
            if include_feat:
                return_dict['sbj_feat'] = sbj_feat
                return_dict['obj_feat'] = obj_feat
                return_dict['prd_feat'] = concat_feat

        return return_dict
# Example #13 ("예제" = "example"; scraped-sample separator kept as a comment)
    def _forward(self, data, support_data, im_info, roidb=None, **rpn_kwargs):
        """Few-shot detection forward pass.

        Runs the backbone on both the query image and the support images,
        correlates averaged support features against the query feature map
        to produce class-conditioned RPN proposals, then classifies and
        regresses proposals by comparing RoI features against the averaged
        support box features.

        Args:
            data: query image blob (NCHW tensor on the current device).
            support_data: support image blob; ``squeeze(0)`` is applied
                before the backbone, so presumably shaped
                ``(1, num_support_imgs, C, H, W)`` -- TODO confirm against
                the data loader.
            im_info: ``[[im_height, im_width, im_scale], ...]``.
            roidb: serialized roidb entries when training; when testing, a
                roidb list whose first entry carries ``'support_cls'`` and
                ``'support_boxes'``.
            **rpn_kwargs: extra blobs forwarded into the RPN losses.

        Returns:
            dict with ``'losses'``/``'metrics'`` during training, or
            ``'rois'``/``'cls_score'``/``'bbox_pred'``/``'blob_conv'``
            during testing.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one byte-blob per image); restore
            # the per-image dicts in place.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        device_id = im_data.get_device()  # NOTE(review): unused below
        return_dict = {}  # A dict to collect return variables
        original_blob_conv = self.Conv_Body(im_data)

        if not self.training:
            # Test-time episode layout comes from the roidb: N-way x K-shot
            # support boxes, stacked way-major (all shots of way 0 first).
            all_cls = roidb[0]['support_cls']
            test_way = len(all_cls)
            test_shot = int(roidb[0]['support_boxes'].shape[0] / test_way)
        else:
            # Training uses a fixed 2-way 5-shot episode; the hard-coded
            # [:5] / [5:10] slices below depend on these values.
            train_way = 2  #2
            train_shot = 5  #5
        support_blob_conv = self.Conv_Body(support_data.squeeze(0))

        img_num = int(support_blob_conv.shape[0])
        img_channel = int(original_blob_conv.shape[1])  # NOTE(review): unused below
        # Construct support rpn_ret: one RoI per support image in
        # (batch_idx, x1, y1, x2, y2) format; np.insert prepends the image
        # index to each support box.
        support_rpn_ret = {
            'rois': np.insert(roidb[0]['support_boxes'][0], 0,
                              0.)[np.newaxis, :]
        }
        if img_num > 1:
            for i in range(img_num - 1):
                support_rpn_ret['rois'] = np.concatenate(
                    (support_rpn_ret['rois'],
                     np.insert(roidb[0]['support_boxes'][i + 1], 0,
                               float(i + 1))[np.newaxis, :]),
                    axis=0)
        # Get support pooled feature (RoIAlign/RoIPool over the support
        # feature maps at the support box locations).
        support_feature = self.roi_feature_transform(
            support_blob_conv,
            support_rpn_ret,
            blob_rois='rois',
            method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
            resolution=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION,
            spatial_scale=self.Conv_Body.spatial_scale,
            sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO)

        blob_conv = original_blob_conv
        # Class-agnostic RPN pass on the raw query features; during
        # training this rpn_ret is replaced by the merged per-way results
        # below, but the call also populates training labels in roidb.
        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        assert not cfg.MODEL.RPN_ONLY

        support_box_feat = self.Box_Head(support_blob_conv, support_rpn_ret)
        if self.training:
            # --- Way 0: shots 0..4 ---
            # Average the 5 support RoI features, global-pool them to a
            # 1x1 kernel, and use it as a depthwise (groups=1024) conv
            # kernel over the query feature map to get a class-specific
            # correlation map. NOTE(review): groups=1024 assumes the
            # backbone emits 1024 channels -- confirm for other backbones.
            support_feature_mean_0 = support_feature[:5].mean(0, True)
            support_pool_feature_0 = self.avgpool(support_feature_mean_0)
            correlation_0 = F.conv2d(original_blob_conv,
                                     support_pool_feature_0.permute(
                                         1, 0, 2, 3),
                                     groups=1024)
            rpn_ret_0 = self.RPN(correlation_0, im_info, roidb)
            box_feat_0 = self.Box_Head(blob_conv, rpn_ret_0)
            support_0 = support_box_feat[:5].mean(
                0, True)  # simple average few shot support features

            # --- Way 1: shots 5..9 (same pipeline as way 0) ---
            support_feature_mean_1 = support_feature[5:10].mean(0, True)
            support_pool_feature_1 = self.avgpool(support_feature_mean_1)
            correlation_1 = F.conv2d(original_blob_conv,
                                     support_pool_feature_1.permute(
                                         1, 0, 2, 3),
                                     groups=1024)
            rpn_ret_1 = self.RPN(correlation_1, im_info, roidb)
            box_feat_1 = self.Box_Head(blob_conv, rpn_ret_1)
            support_1 = support_box_feat[5:10].mean(
                0, True)  # simple average few shot support features

            # Score each way's proposals against its own support prototype.
            cls_score_now_0, bbox_pred_now_0 = self.Box_Outs(
                box_feat_0, support_0)
            cls_score_now_1, bbox_pred_now_1 = self.Box_Outs(
                box_feat_1, support_1)

            cls_score = torch.cat([cls_score_now_0, cls_score_now_1], dim=0)
            bbox_pred = torch.cat([bbox_pred_now_0, bbox_pred_now_1], dim=0)
            # Merge the two per-way rpn_rets: RPN head outputs (logits /
            # bbox preds) are concatenated for the loss; every other blob
            # is taken from way 0 only.
            rpn_ret = {}
            for key in rpn_ret_0.keys():
                if key != 'rpn_cls_logits' and key != 'rpn_bbox_pred':
                    rpn_ret[key] = rpn_ret_0[
                        key]  #np.concatenate((rpn_ret_0[key], rpn_ret_1[key]), axis=0)
                else:
                    rpn_ret[key] = torch.cat([rpn_ret_0[key], rpn_ret_1[key]],
                                             dim=0)

        else:
            # Test time: run the correlation RPN + box head once per way
            # and accumulate scores/preds/rois over all ways.
            for way_id in range(test_way):
                begin = way_id * test_shot
                end = (way_id + 1) * test_shot

                support_feature_mean = support_feature[begin:end].mean(0, True)
                support_pool_feature = self.avgpool(support_feature_mean)
                correlation = F.conv2d(original_blob_conv,
                                       support_pool_feature.permute(
                                           1, 0, 2, 3),
                                       groups=1024)
                rpn_ret = self.RPN(correlation, im_info, roidb)

                if not cfg.MODEL.RPN_ONLY:
                    if cfg.MODEL.SHARE_RES5 and self.training:
                        box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
                    else:
                        box_feat = self.Box_Head(blob_conv, rpn_ret)

                    support = support_box_feat[begin:end].mean(
                        0, True)  # simple average few shot support features

                    cls_score_now, bbox_pred_now = self.Box_Outs(
                        box_feat, support)
                    # Append a column holding this way's class id so the
                    # downstream consumer knows which class each score
                    # row belongs to.
                    cls_now = cls_score_now.new_full(
                        (cls_score_now.shape[0], 1),
                        int(roidb[0]['support_cls'][way_id]))
                    cls_score_now = torch.cat([cls_score_now, cls_now], dim=1)
                    rois = rpn_ret['rois']
                    if way_id == 0:
                        cls_score = cls_score_now
                        bbox_pred = bbox_pred_now
                        rois_all = rois
                    else:
                        cls_score = torch.cat([cls_score, cls_score_now],
                                              dim=0)
                        bbox_pred = torch.cat([bbox_pred, bbox_pred_now],
                                              dim=0)
                        rois_all = np.concatenate((rois_all, rois), axis=0)
                else:
                    # TODO: complete the returns for RPN only situation
                    pass
            rpn_ret['rois'] = rois_all
        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            # Binary fg/bg labels for the comparison head: 1 only where
            # the sampled RoI's class matches this episode's target class.
            target_cls = roidb[0]['target_cls']
            rpn_ret['labels_obj'] = np.array(
                [int(i == target_cls) for i in rpn_ret['labels_int32']])

            # filter other class bbox targets, only supervise the target cls bbox, because the bbox_pred is from the compare feature.
            bg_idx = np.where(rpn_ret['labels_int32'] != target_cls)[0]
            rpn_ret['bbox_targets'][bg_idx] = np.full_like(
                rpn_ret['bbox_targets'][bg_idx], 0.)
            rpn_ret['bbox_inside_weights'][bg_idx] = np.full_like(
                rpn_ret['bbox_inside_weights'][bg_idx], 0.)
            rpn_ret['bbox_outside_weights'][bg_idx] = np.full_like(
                rpn_ret['bbox_outside_weights'][bg_idx], 0.)

            # Way 1's proposals are all negatives for the target class:
            # build all-zero labels/targets/weights and append them so
            # the label arrays line up with the concatenated cls_score /
            # bbox_pred (way 0 rows first, then way 1 rows).
            neg_labels_obj = np.full_like(rpn_ret['labels_obj'], 0.)
            neg_bbox_targets = np.full_like(rpn_ret['bbox_targets'], 0.)
            neg_bbox_inside_weights = np.full_like(
                rpn_ret['bbox_inside_weights'], 0.)
            neg_bbox_outside_weights = np.full_like(
                rpn_ret['bbox_outside_weights'], 0.)

            rpn_ret['labels_obj'] = np.concatenate(
                [rpn_ret['labels_obj'], neg_labels_obj], axis=0)
            rpn_ret['bbox_targets'] = np.concatenate(
                [rpn_ret['bbox_targets'], neg_bbox_targets], axis=0)
            rpn_ret['bbox_inside_weights'] = np.concatenate(
                [rpn_ret['bbox_inside_weights'], neg_bbox_inside_weights],
                axis=0)
            rpn_ret['bbox_outside_weights'] = np.concatenate(
                [rpn_ret['bbox_outside_weights'], neg_bbox_outside_weights],
                axis=0)

            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_obj'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            # The `1 *` factors are explicit (unit) loss weights.
            return_dict['losses']['loss_cls'] = 1 * loss_cls
            return_dict['losses']['loss_bbox'] = 1 * loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
예제 #14
0
def run_eval(args,
             cfg,
             maskRCNN,
             dataloader_test,
             step,
             output_dir,
             test_stats,
             best_eval_result,
             eval_subset='test'):
    """Evaluate `maskRCNN` on `dataloader_test` and track the best result.

    Runs the model over the evaluation set, unpacks the per-GPU /
    per-triplet prediction blobs into per-file dictionaries, then scores
    them with ``vhico_eval`` (recall or mAP depending on
    ``cfg.EVAL_MAP``). The model is switched to eval mode on entry and
    back to train mode before returning.

    Args:
        args: command-line args; only ``args.dataset`` is read here.
        cfg: global config node (EVAL_MAP, NUM_GPUS, BINARY_LOSS,
            TRAIN.BATCH_SIZE_PER_IM, MAX_NUM_HUMAN, ...).
        maskRCNN: the (DataParallel-wrapped) model;
            ``maskRCNN.module.video_name_triplet_dict`` is read for eval.
        dataloader_test: evaluation data loader.
        step: current training step, used for logging.
        output_dir: output directory (unused in this function body).
        test_stats: stats logger with ``tb_log_stats``.
        best_eval_result: best recall seen so far (for is_best tracking).
        eval_subset: name of the subset being evaluated (for logging).

    Returns:
        (is_best, best_eval_result): whether this evaluation improved on
        ``best_eval_result``, and the (possibly updated) best value.
        NOTE(review): when ``cfg.EVAL_MAP`` is set, ``mAP_result`` is
        computed but never folded into ``best_eval_result``.
    """
    is_best = False
    maskRCNN.eval()

    print(
        '------------------------------------------------------------------------------------------------------------'
    )
    print('eval %s: %s' % (eval_subset, 'mAP' if cfg.EVAL_MAP else 'recall'))
    print(
        '------------------------------------------------------------------------------------------------------------'
    )

    ### -------------------------------------------------------------------------------------------------------------------
    # get results
    ### -------------------------------------------------------------------------------------------------------------------
    # Accumulators keyed by triplet name; each maps to a list with one
    # entry per evaluated image.
    file_names = {}
    boxes = {}
    scores = {}
    scores_ori = {}

    human_file_names = {}
    human_boxes = {}
    human_scores = {}
    human_scores_ori = {}

    obj_gt_cls_names = {}
    prd_gt_cls_names = {}

    obj_gt_cls_ids = {}
    prd_gt_cls_ids = {}

    if cfg.BINARY_LOSS:
        binary_preds = {}

    for i, input_data in enumerate(dataloader_test):
        for key in input_data:
            if key != 'roidb':
                input_data[key] = list(map(Variable, input_data[key]))

        # Warn (by printing) when the batch does not fill all GPUs, e.g.
        # the last partial batch.
        if len(input_data['im_info']) != cfg.NUM_GPUS:
            print(len(input_data['im_info']))

        net_outputs_dict = maskRCNN(**input_data)

        for triplet_name in net_outputs_dict.keys():
            # Deep-copy so the prefix-stripping below does not mutate the
            # model's returned structures.
            net_outputs = deepcopy(net_outputs_dict[triplet_name])

            # Outputs from all GPUs are concatenated along dim 0; slice
            # out each GPU's share per iteration.
            for gpu_i in range(cfg.NUM_GPUS):
                # Lazily initialize the per-triplet accumulators on first
                # sight of this triplet.
                if triplet_name not in boxes:
                    boxes[triplet_name] = []
                    scores[triplet_name] = []
                    scores_ori[triplet_name] = []

                    human_boxes[triplet_name] = []
                    human_scores[triplet_name] = []
                    human_scores_ori[triplet_name] = []

                    obj_gt_cls_names[triplet_name] = []
                    prd_gt_cls_names[triplet_name] = []

                    obj_gt_cls_ids[triplet_name] = []
                    prd_gt_cls_ids[triplet_name] = []

                    file_names[triplet_name] = []
                    human_file_names[triplet_name] = []

                    if cfg.BINARY_LOSS:
                        binary_preds[triplet_name] = []

                # Object predictions: take this GPU's BATCH_SIZE_PER_IM
                # slice of every per-image box/score tensor.
                boxes[triplet_name] += [
                    box[(gpu_i) * cfg.TRAIN.BATCH_SIZE_PER_IM:(gpu_i + 1) *
                        cfg.TRAIN.BATCH_SIZE_PER_IM, :]
                    for box in net_outputs['predictions']['box']
                ]
                scores[triplet_name] += [
                    score[(gpu_i) * cfg.TRAIN.BATCH_SIZE_PER_IM:(gpu_i + 1) *
                          cfg.TRAIN.BATCH_SIZE_PER_IM]
                    for score in net_outputs['predictions']['score']
                ]
                scores_ori[triplet_name] += [
                    score_ori[(gpu_i) *
                              cfg.TRAIN.BATCH_SIZE_PER_IM:(gpu_i + 1) *
                              cfg.TRAIN.BATCH_SIZE_PER_IM]
                    for score_ori in net_outputs['predictions']['score_ori']
                ]

                # Sanity check: each gathered tensor covers all GPUs.
                assert len(net_outputs['predictions']['box']
                           [0]) == cfg.TRAIN.BATCH_SIZE_PER_IM * cfg.NUM_GPUS
                assert len(net_outputs['predictions']['score']
                           [0]) == cfg.TRAIN.BATCH_SIZE_PER_IM * cfg.NUM_GPUS
                assert len(net_outputs['predictions']['score_ori']
                           [0]) == cfg.TRAIN.BATCH_SIZE_PER_IM * cfg.NUM_GPUS

                # String metadata is serialized per GPU and concatenated;
                # deserialize reads the first (current GPU's) record.
                file_name = blob_utils.deserialize(
                    net_outputs['predictions']['files'].numpy())
                obj_gt_cls_name = blob_utils.deserialize(
                    net_outputs['predictions']['obj_gt_cls_name'].numpy())
                prd_gt_cls_name = blob_utils.deserialize(
                    net_outputs['predictions']['prd_gt_cls_name'].numpy())
                obj_gt_cls = blob_utils.deserialize(
                    net_outputs['predictions']['obj_gt_cls'].numpy())
                prd_gt_cls = blob_utils.deserialize(
                    net_outputs['predictions']['prd_gt_cls'].numpy())

                file_names[triplet_name] += file_name
                obj_gt_cls_names[triplet_name] += obj_gt_cls_name
                prd_gt_cls_names[triplet_name] += prd_gt_cls_name
                obj_gt_cls_ids[triplet_name] += obj_gt_cls
                prd_gt_cls_ids[triplet_name] += prd_gt_cls

                # Consume this GPU's serialized prefix: re-serialize the
                # record just read to learn its byte length, then strip
                # that prefix so the next gpu_i iteration deserializes the
                # following GPU's record.
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(
                            net_outputs['predictions']['files'].numpy())))
                net_outputs['predictions']['files'] = net_outputs[
                    'predictions']['files'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['predictions']
                                               ['obj_gt_cls_name'].numpy())))
                net_outputs['predictions']['obj_gt_cls_name'] = net_outputs[
                    'predictions']['obj_gt_cls_name'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['predictions']
                                               ['prd_gt_cls_name'].numpy())))
                net_outputs['predictions']['prd_gt_cls_name'] = net_outputs[
                    'predictions']['prd_gt_cls_name'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(
                            net_outputs['predictions']['obj_gt_cls'].numpy())))
                net_outputs['predictions']['obj_gt_cls'] = net_outputs[
                    'predictions']['obj_gt_cls'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(
                            net_outputs['predictions']['prd_gt_cls'].numpy())))
                net_outputs['predictions']['prd_gt_cls'] = net_outputs[
                    'predictions']['prd_gt_cls'][len_gpu_i:]

                if cfg.BINARY_LOSS:
                    # Binary interaction predictions: 2 entries per GPU.
                    binary_preds[triplet_name] += [
                        binary_pred[(gpu_i) * 2:(gpu_i + 1) * 2]
                        for binary_pred in net_outputs['predictions']
                        ['binary_pred']
                    ]

                # human
                # Same slicing scheme as above but with MAX_NUM_HUMAN rois
                # per image.
                num_roi = cfg.MAX_NUM_HUMAN
                human_boxes[triplet_name] += [
                    box[(gpu_i) * num_roi:(gpu_i + 1) * num_roi, :]
                    for box in net_outputs['human_predictions']['box']
                ]
                human_scores[triplet_name] += [
                    score[(gpu_i) * num_roi:(gpu_i + 1) * num_roi]
                    for score in net_outputs['human_predictions']['score']
                ]
                human_scores_ori[triplet_name] += [
                    score_ori[(gpu_i) * num_roi:(gpu_i + 1) * num_roi] for
                    score_ori in net_outputs['human_predictions']['score_ori']
                ]

                human_file_name = blob_utils.deserialize(
                    net_outputs['human_predictions']['files'].numpy())
                human_obj_gt_cls_name = blob_utils.deserialize(
                    net_outputs['human_predictions']
                    ['obj_gt_cls_name'].numpy())
                human_prd_gt_cls_name = blob_utils.deserialize(
                    net_outputs['human_predictions']
                    ['prd_gt_cls_name'].numpy())
                obj_gt_cls = blob_utils.deserialize(
                    net_outputs['human_predictions']['obj_gt_cls'].numpy())
                prd_gt_cls = blob_utils.deserialize(
                    net_outputs['human_predictions']['prd_gt_cls'].numpy())
                human_file_names[triplet_name] += human_file_name

                # Strip this GPU's serialized prefixes from the human
                # predictions as well (same consume pattern as above).
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['human_predictions']
                                               ['files'].numpy())))
                net_outputs['human_predictions']['files'] = net_outputs[
                    'human_predictions']['files'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['human_predictions']
                                               ['obj_gt_cls_name'].numpy())))
                net_outputs['human_predictions'][
                    'obj_gt_cls_name'] = net_outputs['human_predictions'][
                        'obj_gt_cls_name'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['human_predictions']
                                               ['prd_gt_cls_name'].numpy())))
                net_outputs['human_predictions'][
                    'prd_gt_cls_name'] = net_outputs['human_predictions'][
                        'prd_gt_cls_name'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['human_predictions']
                                               ['obj_gt_cls'].numpy())))
                net_outputs['human_predictions']['obj_gt_cls'] = net_outputs[
                    'human_predictions']['obj_gt_cls'][len_gpu_i:]
                len_gpu_i = len(
                    blob_utils.serialize(
                        blob_utils.deserialize(net_outputs['human_predictions']
                                               ['prd_gt_cls'].numpy())))
                net_outputs['human_predictions']['prd_gt_cls'] = net_outputs[
                    'human_predictions']['prd_gt_cls'][len_gpu_i:]

                # Object and human predictions must describe the same
                # images in the same order.
                assert file_name == human_file_name
                assert obj_gt_cls_name == human_obj_gt_cls_name
                assert prd_gt_cls_name == human_prd_gt_cls_name

            # All per-triplet accumulators must stay aligned in length.
            assert len(scores[triplet_name]) == len(
                scores_ori[triplet_name]) == len(boxes[triplet_name]) == len(
                    file_names[triplet_name])
            assert len(human_scores[triplet_name]) == len(
                human_boxes[triplet_name]) == len(
                    human_file_names[triplet_name])
            assert len(file_names[triplet_name]) == len(
                obj_gt_cls_names[triplet_name]) == len(
                    prd_gt_cls_names[triplet_name])

    # Re-key the accumulated lists into per-file prediction dicts,
    # as expected by vhico_eval.
    predictions_all_triplet = {}
    human_predictions_all_triplet = {}

    # NOTE(review): keys are taken from the LAST batch's net_outputs_dict;
    # this assumes every batch returns the same triplet keys -- confirm.
    for triplet_name in net_outputs_dict.keys():
        predictions = {}
        for i, file_name in enumerate(file_names[triplet_name]):
            predictions[file_name] = {}
            predictions[file_name]['boxes'] = boxes[triplet_name][i]
            predictions[file_name]['scores'] = scores[triplet_name][i]
            predictions[file_name]['scores_ori'] = scores_ori[triplet_name][i]
            predictions[file_name]['obj_gt_cls_names'] = obj_gt_cls_names[
                triplet_name][i]
            predictions[file_name]['prd_gt_cls_names'] = prd_gt_cls_names[
                triplet_name][i]
            predictions[file_name]['obj_gt_cls_ids'] = obj_gt_cls_ids[
                triplet_name][i]
            predictions[file_name]['prd_gt_cls_ids'] = prd_gt_cls_ids[
                triplet_name][i]
            if cfg.BINARY_LOSS:
                predictions[file_name]['binary_preds'] = binary_preds[
                    triplet_name][i]
        predictions_all_triplet[triplet_name] = predictions

        # human
        human_predictions = {}
        for i, file_name in enumerate(human_file_names[triplet_name]):
            human_predictions[file_name] = {}
            human_predictions[file_name]['boxes'] = human_boxes[triplet_name][
                i]
            human_predictions[file_name]['scores'] = human_scores[
                triplet_name][i]
            human_predictions[file_name]['scores_ori'] = human_scores_ori[
                triplet_name][i]
        human_predictions_all_triplet[triplet_name] = human_predictions

    eval_input = {}
    eval_input['predictions_object_bbox'] = predictions_all_triplet
    eval_input['predictions_human_bbox'] = human_predictions_all_triplet
    eval_input[
        'video_name_triplet_dict'] = maskRCNN.module.video_name_triplet_dict

    # ------------------------------------------------------------------------------------------------------------
    # Compute Recall and mAP
    # ------------------------------------------------------------------------------------------------------------
    # GT_PATH_TEST / GT_PATH_UNSEEN are presumably module-level constants
    # pointing at the ground-truth annotation files -- not visible here.
    if 'vhico' in args.dataset:
        if not cfg.EVAL_MAP:
            frame_recall_phrase_ko = vhico_eval(cfg,
                                                eval_subset=eval_subset,
                                                eval_input=eval_input,
                                                GT_PATH_TEST=GT_PATH_TEST,
                                                GT_PATH_UNSEEN=GT_PATH_UNSEEN)

            test_stats.tb_log_stats(
                {'frame_recall_phrase_ko': frame_recall_phrase_ko}, step)
            if frame_recall_phrase_ko > best_eval_result:
                is_best = True
                best_eval_result = frame_recall_phrase_ko
                print('best test frame_recall_phrase_ko is %.4f at step %d' %
                      (frame_recall_phrase_ko, step))
        else:
            # NOTE(review): mAP_result is not logged or compared against
            # best_eval_result -- is_best is never set on this path.
            mAP_result = vhico_eval(cfg,
                                    eval_subset=eval_subset,
                                    eval_input=eval_input,
                                    GT_PATH_TEST=GT_PATH_TEST,
                                    GT_PATH_UNSEEN=GT_PATH_UNSEEN)

    ## set the model to training mode
    maskRCNN.train()

    return is_best, best_eval_result
예제 #15
0
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Generalized R-CNN forward pass with optional weak supervision.

        Runs backbone -> RPN -> box head, then branches on mode:
        fully-supervised training (RPN + box/mask/keypoint losses),
        weakly-supervised training (image-level classification loss), or
        testing (raw rois/scores/preds are returned for post-processing).

        Args:
            data: input image blob (NCHW tensor on the current device).
            im_info: [[im_height, im_width, im_scale], ...]
            roidb: serialized roidb entries; required when training.
            **rpn_kwargs: extra blobs forwarded into the RPN losses.

        Returns:
            dict with 'losses'/'metrics' during training, or
            'rois'/'cls_score'/'bbox_pred'/'blob_conv' during testing.
            NOTE(review): if cfg.MODEL.RPN_ONLY is set, cls_score and
            bbox_pred are never assigned and the test branch would raise
            NameError -- the RPN-only returns are still a TODO below.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one byte-blob per image); restore
            # the per-image dicts in place.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()  # NOTE(review): unused below

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # logging.info(f"roi belong to which image: shape {rpn_ret['rois'][:, 0:1].shape}\
        #  \n {rpn_ret['rois'][:, 0]}")
        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared-res5 box head also returns the res5 feature map
                # for reuse by the mask/keypoint heads below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            if self.weak_supervise:
                # Weakly-supervised output head additionally produces a
                # per-roi detection score (WSDDN-style two streams --
                # TODO confirm against Box_Outs implementation).
                cls_score, det_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training and not self.weak_supervise:
            # --- Fully-supervised training losses ---
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        elif self.training and self.weak_supervise:
            # Weak supervision image-level loss
            # logging.info(f"image-level labels: shape {rpn_ret['image_labels_vec'].shape}\n {rpn_ret['image_labels_vec']}")
            # logging.info(f"cls score: shape {cls_score.shape}\n {cls_score}")
            # logging.info(f"det score: shape {det_score.shape}\n {det_score}")
            #
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # Image-level BCE classification loss plus a spatial
            # regularization term, computed from the two score streams.
            image_loss_cls, acc_score, reg = fast_rcnn_heads.image_level_loss(
                cls_score, det_score, rpn_ret['rois'],
                rpn_ret['image_labels_vec'], self.bceloss, box_feat)
            return_dict['losses']['image_loss_cls'] = image_loss_cls
            return_dict['losses']['spatial_reg'] = reg

            return_dict['metrics']['accuracy_cls'] = acc_score

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self,
                 data,
                 im_info,
                 do_vis=False,
                 dataset_name=None,
                 roidb=None,
                 use_gt_labels=False,
                 **rpn_kwargs):
        """Forward pass of the relationship-detection (RelDN-style) model.

        Args:
            data: input image blob (tensor; fed straight to the conv body).
            im_info: per-image rows ``[im_height, im_width, im_scale]``.
            do_vis: testing only; if True, also return the backbone conv
                blobs for visualization.
            dataset_name: serialized dataset name, or None to fall back to
                the first configured TRAIN/TEST dataset.
            roidb: training -- list of serialized roidb entries; testing --
                an optional dict carrying GT ``sbj_gt_boxes``/``obj_gt_boxes``
                (and classes) to evaluate with ground-truth boxes.
            use_gt_labels: testing with GT boxes only; if True, trust the GT
                object classes instead of the classifier's argmax.
            **rpn_kwargs: extra blobs forwarded to the RPN losses.

        Returns:
            dict: when training, ``'losses'``/``'metrics'`` (each a dict of
            1-dim tensors); otherwise sbj/obj rois, labels, scores and the
            predicate score blobs.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); decode it here.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        if dataset_name is not None:
            dataset_name = blob_utils.deserialize(dataset_name)
        else:
            dataset_name = cfg.TRAIN.DATASETS[
                0] if self.training else cfg.TEST.DATASETS[
                    0]  # assuming only one dataset per run

        # NOTE(review): device_id appears unused below -- confirm before removing.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone features: one tower for the box branch and (unless a
        # shared relation pyramid is used) a second tower for predicates.
        blob_conv = self.Conv_Body(im_data)
        if not cfg.MODEL.USE_REL_PYRAMID:
            blob_conv_prd = self.Prd_RCNN.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]
            if not cfg.MODEL.USE_REL_PYRAMID:
                blob_conv_prd = blob_conv_prd[-self.num_roi_levels:]
            else:
                # Derive the predicate feature pyramid from the box pyramid.
                blob_conv_prd = self.RelPyramid(blob_conv)

        # Object-detection branch: RoI features -> class scores + box deltas.
        if cfg.MODEL.SHARE_RES5 and self.training:
            box_feat, res5_feat = self.Box_Head(blob_conv,
                                                rpn_ret,
                                                use_relu=True)
        else:
            box_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
        cls_score, bbox_pred = self.Box_Outs(box_feat)

        # now go through the predicate branch
        use_relu = False if cfg.MODEL.NO_FC7_RELU else True
        if self.training:
            # Use only foreground proposals as "detections" to build
            # relation proposals during training.
            fg_inds = np.where(rpn_ret['labels_int32'] > 0)[0]
            det_rois = rpn_ret['rois'][fg_inds]
            det_labels = rpn_ret['labels_int32'][fg_inds]
            det_scores = F.softmax(cls_score[fg_inds], dim=1)
            rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info,
                                 dataset_name, roidb)
            # Subject/object appearance features: dedicated S/O heads when
            # ADD_SO_SCORES is on, otherwise the shared box head.
            if cfg.MODEL.ADD_SO_SCORES:
                sbj_feat = self.S_Head(blob_conv,
                                       rel_ret,
                                       rois_name='sbj_rois',
                                       use_relu=use_relu)
                obj_feat = self.O_Head(blob_conv,
                                       rel_ret,
                                       rois_name='obj_rois',
                                       use_relu=use_relu)
            else:
                sbj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='sbj_rois',
                                         use_relu=use_relu)
                obj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='obj_rois',
                                         use_relu=use_relu)
            # Extra features over the "positive" roi sets needed by the
            # node-contrastive losses (sbj-anchored and obj-anchored pools).
            if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                if cfg.MODEL.ADD_SO_SCORES:
                    # sbj
                    sbj_feat_sbj_pos = self.S_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_sbj_pos',
                        use_relu=use_relu)
                    obj_feat_sbj_pos = self.O_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_sbj_pos',
                        use_relu=use_relu)
                    # obj
                    sbj_feat_obj_pos = self.S_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_obj_pos',
                        use_relu=use_relu)
                    obj_feat_obj_pos = self.O_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_obj_pos',
                        use_relu=use_relu)
                else:
                    # sbj
                    sbj_feat_sbj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_sbj_pos',
                        use_relu=use_relu)
                    obj_feat_sbj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_sbj_pos',
                        use_relu=use_relu)
                    # obj
                    sbj_feat_obj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_obj_pos',
                        use_relu=use_relu)
                    obj_feat_obj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_obj_pos',
                        use_relu=use_relu)
        else:
            # Testing: either GT boxes are provided (evaluate relations on
            # ground-truth pairs) or we run the full detection pipeline.
            if roidb is not None:
                # Scale GT boxes into network input coordinates and build
                # all sbj/obj pairs plus their union (rel) rois by hand.
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                sbj_boxes = roidb['sbj_gt_boxes']
                obj_boxes = roidb['obj_gt_boxes']
                sbj_rois = sbj_boxes * im_scale
                obj_rois = obj_boxes * im_scale
                # Batch index column (single image => all zeros).
                repeated_batch_idx = 0 * blob_utils.ones(
                    (sbj_rois.shape[0], 1))
                sbj_rois = np.hstack((repeated_batch_idx, sbj_rois))
                obj_rois = np.hstack((repeated_batch_idx, obj_rois))
                rel_rois = box_utils_rel.rois_union(sbj_rois, obj_rois)
                rel_ret = {}
                rel_ret['sbj_rois'] = sbj_rois
                rel_ret['obj_rois'] = obj_rois
                rel_ret['rel_rois'] = rel_rois
                if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
                    lvl_min = cfg.FPN.ROI_MIN_LEVEL
                    lvl_max = cfg.FPN.ROI_MAX_LEVEL
                    rois_blob_names = ['sbj_rois', 'obj_rois', 'rel_rois']
                    for rois_blob_name in rois_blob_names:
                        # Add per FPN level roi blobs named like: <rois_blob_name>_fpn<lvl>
                        target_lvls = fpn_utils.map_rois_to_fpn_levels(
                            rel_ret[rois_blob_name][:, 1:5], lvl_min, lvl_max)
                        fpn_utils.add_multilevel_roi_blobs(
                            rel_ret, rois_blob_name, rel_ret[rois_blob_name],
                            target_lvls, lvl_min, lvl_max)
                # Classify the GT boxes with the detection head to obtain
                # per-entity labels/scores (unless GT labels are trusted).
                sbj_det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='sbj_rois',
                                             use_relu=True)
                sbj_cls_scores, _ = self.Box_Outs(sbj_det_feat)
                sbj_cls_scores = sbj_cls_scores.data.cpu().numpy()
                obj_det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='obj_rois',
                                             use_relu=True)
                obj_cls_scores, _ = self.Box_Outs(obj_det_feat)
                obj_cls_scores = obj_cls_scores.data.cpu().numpy()
                if use_gt_labels:
                    sbj_labels = roidb['sbj_gt_classes']  # start from 0
                    obj_labels = roidb['obj_gt_classes']  # start from 0
                    sbj_scores = np.ones_like(sbj_labels, dtype=np.float32)
                    obj_scores = np.ones_like(obj_labels, dtype=np.float32)
                else:
                    # argmax/amax over columns 1: skips the background class.
                    sbj_labels = np.argmax(sbj_cls_scores[:, 1:], axis=1)
                    obj_labels = np.argmax(obj_cls_scores[:, 1:], axis=1)
                    sbj_scores = np.amax(sbj_cls_scores[:, 1:], axis=1)
                    obj_scores = np.amax(obj_cls_scores[:, 1:], axis=1)
                rel_ret['sbj_scores'] = sbj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['obj_scores'] = obj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['sbj_labels'] = sbj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['obj_labels'] = obj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['all_sbj_labels_int32'] = sbj_labels.astype(np.int32,
                                                                    copy=False)
                rel_ret['all_obj_labels_int32'] = obj_labels.astype(np.int32,
                                                                    copy=False)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat = box_utils_rel.get_spt_features(
                        sbj_boxes, obj_boxes, im_w, im_h)
                    rel_ret['spt_feat'] = spt_feat
                if cfg.MODEL.ADD_SO_SCORES:
                    sbj_feat = self.S_Head(blob_conv,
                                           rel_ret,
                                           rois_name='sbj_rois',
                                           use_relu=use_relu)
                    obj_feat = self.O_Head(blob_conv,
                                           rel_ret,
                                           rois_name='obj_rois',
                                           use_relu=use_relu)
                else:
                    sbj_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='sbj_rois',
                                             use_relu=use_relu)
                    obj_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='obj_rois',
                                             use_relu=use_relu)
            else:
                # Full-detection mode: progressively lower the score
                # threshold until at least one relation proposal survives.
                score_thresh = cfg.TEST.SCORE_THRESH
                while score_thresh >= -1e-06:  # a negative value very close to 0.0
                    det_rois, det_labels, det_scores = \
                        self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                    rel_ret = self.RelPN(det_rois, det_labels, det_scores,
                                         im_info, dataset_name, roidb)
                    valid_len = len(rel_ret['rel_rois'])
                    if valid_len > 0:
                        break
                    logger.info(
                        'Got {} rel_rois when score_thresh={}, changing to {}'.
                        format(valid_len, score_thresh, score_thresh - 0.01))
                    score_thresh -= 0.01
                # Extract features once per detection, then gather per-pair
                # sbj/obj features via the pairing indices from RelPN.
                if cfg.MODEL.ADD_SO_SCORES:
                    det_s_feat = self.S_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    det_o_feat = self.O_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    sbj_feat = det_s_feat[rel_ret['sbj_inds']]
                    obj_feat = det_o_feat[rel_ret['obj_inds']]
                else:
                    det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    sbj_feat = det_feat[rel_ret['sbj_inds']]
                    obj_feat = det_feat[rel_ret['obj_inds']]

        # Predicate (union-region) features from the predicate tower.
        rel_feat = self.Prd_RCNN.Box_Head(blob_conv_prd,
                                          rel_ret,
                                          rois_name='rel_rois',
                                          use_relu=use_relu)

        # Concatenate subject / predicate / object features channel-wise.
        spo_feat = torch.cat((sbj_feat, rel_feat, obj_feat), dim=1)
        if cfg.MODEL.USE_SPATIAL_FEAT:
            spt_feat = rel_ret['spt_feat']
        else:
            spt_feat = None
        if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
            sbj_labels = rel_ret['all_sbj_labels_int32']
            obj_labels = rel_ret['all_obj_labels_int32']
        else:
            sbj_labels = None
            obj_labels = None

        # prd_scores is the visual scores. See reldn_heads.py
        prd_scores, prd_bias_scores, prd_spt_scores, ttl_cls_scores, sbj_cls_scores, obj_cls_scores = \
            self.RelDN(spo_feat, spt_feat, sbj_labels, obj_labels, sbj_feat, obj_feat)

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # Predicate classification losses: bias/spatial heads get their
            # own losses only when their scores are not folded into the
            # combined (ADD_SCORES_ALL) total score.
            if cfg.MODEL.USE_FREQ_BIAS and not cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_bias, accuracy_cls_bias = reldn_heads.reldn_losses(
                    prd_bias_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_bias'] = loss_cls_bias
                return_dict['metrics']['accuracy_cls_bias'] = accuracy_cls_bias
            if cfg.MODEL.USE_SPATIAL_FEAT and not cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_spt, accuracy_cls_spt = reldn_heads.reldn_losses(
                    prd_spt_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_spt'] = loss_cls_spt
                return_dict['metrics']['accuracy_cls_spt'] = accuracy_cls_spt
            if cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_ttl, accuracy_cls_ttl = reldn_heads.reldn_losses(
                    ttl_cls_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_ttl'] = loss_cls_ttl
                return_dict['metrics']['accuracy_cls_ttl'] = accuracy_cls_ttl
            else:
                loss_cls_prd, accuracy_cls_prd = reldn_heads.reldn_losses(
                    prd_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_prd'] = loss_cls_prd
                return_dict['metrics']['accuracy_cls_prd'] = accuracy_cls_prd
            # Node-contrastive losses: rerun RelDN over the sbj-anchored and
            # obj-anchored positive pools and contrast their total scores.
            if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                # sbj
                rel_feat_sbj_pos = self.Prd_RCNN.Box_Head(
                    blob_conv_prd,
                    rel_ret,
                    rois_name='rel_rois_sbj_pos',
                    use_relu=use_relu)
                spo_feat_sbj_pos = torch.cat(
                    (sbj_feat_sbj_pos, rel_feat_sbj_pos, obj_feat_sbj_pos),
                    dim=1)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat_sbj_pos = rel_ret['spt_feat_sbj_pos']
                else:
                    spt_feat_sbj_pos = None
                if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
                    sbj_labels_sbj_pos_fg = rel_ret[
                        'sbj_labels_sbj_pos_fg_int32']
                    obj_labels_sbj_pos_fg = rel_ret[
                        'obj_labels_sbj_pos_fg_int32']
                else:
                    sbj_labels_sbj_pos_fg = None
                    obj_labels_sbj_pos_fg = None
                _, prd_bias_scores_sbj_pos, _, ttl_cls_scores_sbj_pos, _, _ = \
                    self.RelDN(spo_feat_sbj_pos, spt_feat_sbj_pos, sbj_labels_sbj_pos_fg, obj_labels_sbj_pos_fg, sbj_feat_sbj_pos, obj_feat_sbj_pos)
                # obj
                rel_feat_obj_pos = self.Prd_RCNN.Box_Head(
                    blob_conv_prd,
                    rel_ret,
                    rois_name='rel_rois_obj_pos',
                    use_relu=use_relu)
                spo_feat_obj_pos = torch.cat(
                    (sbj_feat_obj_pos, rel_feat_obj_pos, obj_feat_obj_pos),
                    dim=1)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat_obj_pos = rel_ret['spt_feat_obj_pos']
                else:
                    spt_feat_obj_pos = None
                if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
                    sbj_labels_obj_pos_fg = rel_ret[
                        'sbj_labels_obj_pos_fg_int32']
                    obj_labels_obj_pos_fg = rel_ret[
                        'obj_labels_obj_pos_fg_int32']
                else:
                    sbj_labels_obj_pos_fg = None
                    obj_labels_obj_pos_fg = None
                _, prd_bias_scores_obj_pos, _, ttl_cls_scores_obj_pos, _, _ = \
                    self.RelDN(spo_feat_obj_pos, spt_feat_obj_pos, sbj_labels_obj_pos_fg, obj_labels_obj_pos_fg, sbj_feat_obj_pos, obj_feat_obj_pos)
                if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS:
                    loss_contrastive_sbj, loss_contrastive_obj = reldn_heads.reldn_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_contrastive_sbj'] = loss_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_WEIGHT
                    return_dict['losses'][
                        'loss_contrastive_obj'] = loss_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_WEIGHT
                if cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS:
                    loss_so_contrastive_sbj, loss_so_contrastive_obj = reldn_heads.reldn_so_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_so_contrastive_sbj'] = loss_so_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_SO_AWARE_WEIGHT
                    return_dict['losses'][
                        'loss_so_contrastive_obj'] = loss_so_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_SO_AWARE_WEIGHT
                if cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                    loss_p_contrastive_sbj, loss_p_contrastive_obj = reldn_heads.reldn_p_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        prd_bias_scores_sbj_pos, prd_bias_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_p_contrastive_sbj'] = loss_p_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_P_AWARE_WEIGHT
                    return_dict['losses'][
                        'loss_p_contrastive_obj'] = loss_p_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_P_AWARE_WEIGHT

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)
        else:
            # Testing
            return_dict['sbj_rois'] = rel_ret['sbj_rois']
            return_dict['obj_rois'] = rel_ret['obj_rois']
            return_dict['sbj_labels'] = rel_ret['sbj_labels']
            return_dict['obj_labels'] = rel_ret['obj_labels']
            return_dict['sbj_scores'] = rel_ret['sbj_scores']
            return_dict['obj_scores'] = rel_ret['obj_scores']
            return_dict['prd_scores'] = prd_scores
            if cfg.MODEL.USE_FREQ_BIAS:
                return_dict['prd_scores_bias'] = prd_bias_scores
            if cfg.MODEL.USE_SPATIAL_FEAT:
                return_dict['prd_scores_spt'] = prd_spt_scores
            if cfg.MODEL.ADD_SCORES_ALL:
                return_dict['prd_ttl_scores'] = ttl_cls_scores
            if do_vis:
                return_dict['blob_conv'] = blob_conv
                return_dict['blob_conv_prd'] = blob_conv_prd

        return return_dict
# --- Example #17 (scraped-source separator; original vote count: 0) ---
    def _forward(self, data, query, im_info, query_type, roidb=None, **rpn_kwargs):
        """Forward pass for query-guided (few-shot style) detection.

        Args:
            data: target image blob (fed to the conv body).
            query: sequence of query image blobs, one per shot; their
                features are pooled (mean over shots) before matching.
            im_info: per-image rows [im_height, im_width, im_scale].
            query_type: query label information passed through to the
                fast R-CNN losses.
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: extra blobs forwarded to the RPN losses.

        Returns:
            dict: 'losses'/'metrics' plus rois/cls_score/bbox_pred when
            training; rois/cls_score/bbox_pred (and conv blobs) otherwise.
        """

        #query_type = query_type.item()

        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); decode it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id appears unused below -- confirm before removing.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # feed image data to base model to obtain base feature map
        blob_conv = self.Conv_Body(im_data)

        # Run every query shot through the same backbone.
        query_conv = []
        shot = len(query)
        for i in range(shot):
            query_conv.append(self.Conv_Body(query[i]))

        def pooling(feats, method='avg', dim = 0):
            # Stack per-shot features and reduce across shots.
            # NOTE(review): raises UnboundLocalError for any method other
            # than 'avg'/'max' -- only 'avg' is used below.
            feats = torch.stack(feats)
            if method == 'avg':
                feat = torch.mean(feats, dim=dim)
            elif method == 'max':
                feat, _ = torch.max(feats, dim=dim)
            return feat

        if cfg.FPN.FPN_ON:
            # Transpose [shot][level] -> [level][shot], then average the
            # shots at each pyramid level.
            query_conv = list(map(list, zip(*query_conv)))
            query_conv = [pooling(QC_i, method='avg', dim = 0) for QC_i in query_conv]

            rpn_feat = []
            act_feat = []
            act_aim = []
            c_weight = []  # NOTE(review): collected but not consumed here -- confirm

            # With 5 pyramid levels, the coarsest blob bypasses matching and
            # is fed to the RPN unchanged.
            if len(blob_conv) == 5:
                start_match_idx = 1
                #sa_blob_conv = self.sa(blob_conv[0])
                sa_blob_conv = blob_conv[0]
                rpn_feat.append(sa_blob_conv)
            else:
                start_match_idx = 0

            # Match image and query features level by level.
            for IP, QP in zip(blob_conv[start_match_idx:], query_conv[start_match_idx:]):
                _rpn_feat, _act_feat, _act_aim, _c_weight = self.match_net(IP, QP)
                rpn_feat.append(_rpn_feat)
                act_feat.append(_act_feat)
                act_aim.append(_act_aim)
                c_weight.append(_c_weight)
                """
                correlation = []
                QP_pool = self.global_avgpool(QP)
                #IP_sa = self.sa(IP)
                IP_sa = IP
                for IP_sa_batch, QP_pool_batch in zip(IP_sa, QP_pool):
                    IP_sa_batch, QP_pool_batch = IP_sa_batch.unsqueeze(0), QP_pool_batch.unsqueeze(0)
                    correlation.append(F.conv2d(IP_sa_batch, QP_pool_batch.permute(1,0,2,3), groups=IP.shape[1]).squeeze(0))
                correlation = torch.stack(correlation)
                rpn_feat.append(correlation)
                act_feat.append(correlation)
                act_aim.append(QP)
                """
        else:
            # Single-level backbone: pool the shots and match once.
            query_conv = pooling(query_conv)
            rpn_feat, act_feat, act_aim, c_weight = self.match_net(blob_conv, query_conv)
            """
            correlation = []
            QP_pool = self.global_avgpool(QP)
            #IP_sa = self.sa(IP)
            IP_sa = IP
            for IP_sa_batch, QP_pool_batch in zip(IP_sa, QP_pool):
                IP_sa_batch, QP_pool_batch = IP_sa_batch.unsqueeze(0), QP_pool_batch.unsqueeze(0)
                correlation.append(F.conv2d(IP_sa_batch, QP_pool_batch.permute(1,0,2,3), groups=IP.shape[1]).squeeze(0))
            correlation = torch.stack(correlation)
            rpn_feat = correlation
            act_feat = correlation
            act_aim = query_conv
            """
        # Proposals come from the matched (query-conditioned) features.
        rpn_ret = self.RPN(rpn_feat, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv
            return_dict['query_conv'] = query_conv

        if not cfg.MODEL.RPN_ONLY:
            # Box branch: RELATION_RCNN heads additionally consume the
            # pooled query features (act_aim) and emit query box features.
            if not cfg.FPN.FPN_ON:
                if cfg.MODEL.SHARE_RES5 and self.training:
                    if cfg.RELATION_RCNN:
                        box_feat, query_box_feat, res5_feat, query_res5_feat = self.Box_Head(act_feat, act_aim, rpn_ret)
                    else:
                        box_feat, res5_feat = self.Box_Head(act_feat, rpn_ret)
                else:
                    if cfg.RELATION_RCNN:
                        box_feat, query_box_feat= self.Box_Head(act_feat, act_aim, rpn_ret)
                    else:
                        box_feat = self.Box_Head(act_feat, rpn_ret)

                if cfg.RELATION_RCNN:
                    cls_score, bbox_pred = self.Box_Outs(box_feat, query_box_feat)
                else:
                    cls_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                if cfg.RELATION_RCNN:
                    box_feat, query_box_feat = self.Box_Head(act_feat, act_aim, rpn_ret)
                    cls_score, bbox_pred = self.Box_Outs(box_feat, query_box_feat)
                else:
                    box_feat = self.Box_Head(act_feat, rpn_ret)
                    cls_score, bbox_pred = self.Box_Outs(box_feat)

        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            #rpn_kwargs.update({'query_type': query_type})
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls, _ = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'], rpn_ret['rois'], query_type, use_marginloss=False)
            #return_dict['losses']['margin_loss'] = margin_loss
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    if cfg.RELATION_RCNN:
                        mask_feat = self.Mask_Head(res5_feat, query_res5_feat, rpn_ret,
                                                roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                    else:
                        mask_feat = self.Mask_Head(res5_feat, rpn_ret, roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    if cfg.RELATION_RCNN:
                        mask_feat = self.Mask_Head(act_feat, act_aim, rpn_ret)
                    else:
                        mask_feat = self.Mask_Head(act_feat, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                # NOTE(review): the mask loss is computed nowhere -- mask_pred
                # is produced but never contributes to 'losses'. Confirm this
                # is intentional before relying on MASK_ON training.
                #loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'], rpn_ret['mask_rois'], query_type)
                #return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

            # Unlike the standard Detectron flow, predictions are also
            # returned during training.
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
# Example #18 (score: 0)
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass of the detector, with optional LESION position branches.

        Args:
            data: input image blob, passed straight to the backbone.
            im_info: per-image info blob consumed by the RPN.
            roidb: serialized roidb entries; only used (and deserialized)
                during training.
            **rpn_kwargs: extra blobs forwarded into the RPN loss computation.

        Returns:
            dict: 'losses' and 'metrics' sub-dicts while training; 'rois',
            'cls_score', 'bbox_pred', 'blob_conv' (plus position predictions
            when enabled) at test time.
        """
        im_data = data
        if self.training:
            # Each roidb blob holds one serialized entry; unpack them all.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()  # NOTE(review): unused below
        return_dict = {}  # A dict to collect return variables

        # The backbone's return signature depends on the LESION config:
        # - USE_POSITION: also returns res5 features for the position head;
        # - SHALLOW_POSITION: also returns position cls/reg predictions.
        if cfg.LESION.USE_POSITION:
            blob_conv, res5_feat = self.Conv_Body(im_data)
        elif cfg.LESION.SHALLOW_POSITION:
            blob_conv, pos_cls_pred, pos_reg_pred = self.Conv_Body(im_data)
        else:
            blob_conv = self.Conv_Body(im_data)

        # Optional feature transform before the RPN (left/right-view style).
        if cfg.MODEL.LR_VIEW_ON or cfg.MODEL.GIF_ON or cfg.MODEL.LRASY_MAHA_ON:
            blob_conv = self._get_lrview_blob_conv(blob_conv)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: the box head also yields res5 features that
                # the mask/keypoint heads may reuse below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

            if cfg.LESION.USE_POSITION:
                # Deep position branch: predict from the backbone's res5 features.
                position_feat = self.Position_Head(res5_feat)
                pos_cls_pred = self.Position_Cls_Outs(position_feat)
                pos_reg_pred = self.Position_Reg_Outs(position_feat)
                return_dict['pos_cls_pred'] = pos_cls_pred
                return_dict['pos_reg_pred'] = pos_reg_pred
            if cfg.LESION.SHALLOW_POSITION:
                # Shallow branch: predictions were produced by the backbone above.
                return_dict['pos_cls_pred'] = pos_cls_pred
                return_dict['pos_reg_pred'] = pos_reg_pred
            if cfg.LESION.POSITION_RCNN:
                # Per-RoI position predictions; consumed only by the training
                # loss below — they are not added to return_dict.
                pos_cls_rcnn, pos_reg_rcnn = self.Position_RCNN(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss: forward the RPN cls-logit / bbox-pred blobs to the helper.
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One cls/bbox RPN loss per FPN level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if cfg.MODEL.FASTER_RCNN:
                # bbox loss (Fast R-CNN classification + regression)
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls
            # RCNN Position branch loss (per-RoI)
            if cfg.LESION.POSITION_RCNN:
                pos_cls_rcnn_loss, pos_reg_rcnn_loss, accuracy_position_rcnn = position_rcnn_losses(
                    pos_cls_rcnn, pos_reg_rcnn, roidb)
                return_dict['losses']['RcnnPosCls_loss'] = pos_cls_rcnn_loss
                return_dict['losses']['RcnnPosReg_loss'] = pos_reg_rcnn_loss
                return_dict['metrics'][
                    'accuracy_position_rcnn'] = accuracy_position_rcnn
            # Shallow Position Branch loss
            elif cfg.LESION.SHALLOW_POSITION:
                pos_cls_loss, pos_reg_loss, accuracy_position = position_losses(
                    pos_cls_pred, pos_reg_pred, roidb)
                #pos_reg_loss = position_reg_losses(reg_pred, roidb)
                return_dict['losses']['pos_cls_loss'] = pos_cls_loss
                return_dict['losses']['pos_reg_loss'] = pos_reg_loss
                return_dict['metrics']['accuracy_position'] = accuracy_position
            # Position Branch loss
            elif cfg.LESION.USE_POSITION:
                # NOTE(review): the position head already ran above when
                # USE_POSITION is on, so this re-runs it during training —
                # presumably intentional, but worth confirming.
                position_feat = self.Position_Head(res5_feat)
                pos_cls_pred = self.Position_Cls_Outs(position_feat)
                pos_reg_pred = self.Position_Reg_Outs(position_feat)
                pos_cls_loss, pos_reg_loss, accuracy_position = position_losses(
                    pos_cls_pred, pos_reg_pred, roidb)
                #pos_reg_loss = position_reg_losses(reg_pred, roidb)
                return_dict['losses']['pos_cls_loss'] = pos_cls_loss
                return_dict['losses']['pos_reg_loss'] = pos_reg_loss
                return_dict['metrics']['accuracy_position'] = accuracy_position

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors: promote
            # every scalar loss/metric to a 1-element tensor before return.
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Core forward pass, shared by training and testing.

        Args:
            data: input image blob for the backbone.
            im_info: per-image info blob consumed by the RPN.
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: extra blobs forwarded into the RPN losses.

        Returns:
            dict: 'losses'/'metrics' sub-dicts when training; 'rois',
            'cls_score', 'bbox_pred' and 'blob_conv' when testing.
        """
        im_data = data
        if self.training:
            # Each roidb blob carries one serialized entry; unpack them all.
            roidb = [blob_utils.deserialize(serialized)[0] for serialized in roidb]

        device_id = im_data.get_device()

        return_dict = {}  # accumulates every output of this pass

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Keep only the feature levels consumed by the RoI heads; the
            # remaining levels exist solely for RPN proposal generation.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if cfg.MODEL.RPN_ONLY:
            # TODO: complete the returns for RPN only situation
            pass
        else:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: box head also yields features reused by mask/kps heads.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)

        if self.training:
            losses = {}
            metrics = {}
            return_dict['losses'] = losses
            return_dict['metrics'] = metrics

            # RPN loss: hand the RPN cls-logit / bbox-pred blobs to the helper.
            for blob_name, blob in rpn_ret.items():
                if blob_name.startswith('rpn_cls_logits') or blob_name.startswith('rpn_bbox_pred'):
                    rpn_kwargs[blob_name] = blob
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One cls/bbox RPN loss per FPN level.
                fpn_levels = range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)
                for idx, lvl in enumerate(fpn_levels):
                    losses['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[idx]
                    losses['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[idx]
            else:
                losses['loss_rpn_cls'] = loss_rpn_cls
                losses['loss_rpn_bbox'] = loss_rpn_bbox

            # Fast R-CNN box-head losses (classification + regression).
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            losses['loss_cls'] = loss_cls
            losses['loss_bbox'] = loss_bbox
            metrics['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # Mask loss.
                losses['loss_mask'] = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (also true in
                    # Detectron); RPN would need to emit 'roi_has_keypoints_int32'.
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # Keypoint loss; normalizer choice is config-driven.
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    losses['loss_kps'] = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    losses['loss_kps'] = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])

            # Work around the pytorch 0.4 bug when gathering 0-dim tensors:
            # promote every scalar loss/metric to a 1-element tensor.
            for name in losses:
                losses[name] = losses[name].unsqueeze(0)
            for name in metrics:
                metrics[name] = metrics[name].unsqueeze(0)

        else:
            # Testing outputs.
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
# Example #20 (score: 0)
    def _forward(self, data, im_info, roidb=None, rois=None, **rpn_kwargs):
        """Forward pass with an optional relation-network refinement branch.

        Args:
            data: input image blob (or, in the GEO relation mode, features
                used directly in place of backbone output).
            im_info: per-image info blob consumed by the RPN.
            roidb: serialized roidb entries; deserialized during training.
            rois: externally supplied RoIs, used when the RPN does not
                produce any ('rois' missing from its output).
            **rpn_kwargs: extra blobs forwarded into the RPN losses.

        Returns:
            dict: 'losses'/'metrics' sub-dicts when training; otherwise the
            full RPN output plus 'cls_score'/'bbox_pred' (unless RPN_ONLY).
        """
        im_data = data
        if self.training:
            # Each roidb blob carries one serialized entry; unpack them all.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()  # NOTE(review): unused below

        return_dict = {}  # A dict to collect return variables

        # In the GEO relation-training mode the backbone is skipped and the
        # input is used as the "features" directly.
        if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
            blob_conv = im_data
        else:
            blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)
        # If rpn_ret doesn't contain 'rois', the rois must have been given
        # by the caller; fall back to those.
        if 'rois' not in rpn_ret:
            rpn_ret['rois'] = rois
        if hasattr(self, '_ignore_classes') and 'labels_int32' in rpn_ret:
            # Turn ignore classes labels to 0, because they are treated as background
            rpn_ret['labels_int32'][np.isin(rpn_ret['labels_int32'],
                                            self._ignore_classes)] = 0

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                # GEO relation training: box features are not computed from the
                # image; substitute zeros of the expected shape.
                if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
                    box_feat = blob_conv.new_zeros(
                        rpn_ret['labels_int32'].shape[0],
                        self.Box_Head.dim_out)
                else:
                    box_feat = self.Box_Head(blob_conv, rpn_ret)
            if cfg.MODEL.RELATION_NET_INPUT == 'GEO' and not cfg.REL_INFER.TRAIN and cfg.MODEL.NUM_RELATIONS > 0 and self.training:
                # Placeholder predictions (all zeros) for the GEO mode.
                cls_score = box_feat.new_zeros(
                    rpn_ret['labels_int32'].shape[0], cfg.MODEL.NUM_CLASSES)
                bbox_pred = box_feat.new_zeros(
                    rpn_ret['labels_int32'].shape[0],
                    4 * cfg.MODEL.NUM_CLASSES)
            else:
                cls_score, bbox_pred = self.Box_Outs(box_feat)
            if cfg.TEST.TAGGING:
                # Diagnostic: classification accuracy before relation refinement.
                rois_label = torch.from_numpy(
                    rpn_ret['labels_int32'].astype('int64')).to(
                        cls_score.device)
                accuracy_cls = cls_score.max(1)[1].eq(rois_label).float().mean(
                    dim=0)
                print('Before refine:', accuracy_cls.item())

        if cfg.MODEL.NUM_RELATIONS > 0 and self.training:
            # Relation scores/labels for the relation-network training loss.
            rel_scores, rel_labels = self.Rel_Outs(rpn_ret['rois'],
                                                   box_feat,
                                                   rpn_ret['labels_int32'],
                                                   roidb=roidb)
        elif cfg.MODEL.NUM_RELATIONS > 0 and cfg.TEST.USE_REL_INFER:
            if cfg.TEST.TAGGING:
                # Tagging: refine class scores with relation inference over all RoIs.
                rel_scores = self.Rel_Outs(rpn_ret['rois'], box_feat)
                cls_score = self.Rel_Inf(cls_score, rel_scores, roidb=roidb)
            else:  # zero shot detection
                # Keep RoIs that (a) never rank background in their top-k
                # scores AND (b) are among the top-100 by max foreground
                # score — implemented as "flag count >= 2".
                filt = (cls_score.topk(cfg.TEST.REL_INFER_PROPOSAL,
                                       1)[1] == 0).float().sum(1) == 0
                filt[cls_score[:,
                               1:].max(1)[0].topk(min(100, cls_score.shape[0]),
                                                  0)[1]] += 1
                filt = filt >= 2
                if filt.sum() == 0:
                    print('all background?')
                else:
                    # Refine only the selected subset of RoIs in place.
                    tmp_rois = rpn_ret['rois'][filt.cpu().numpy().astype(
                        'bool')]

                    rel_scores = self.Rel_Outs(tmp_rois, box_feat[filt])
                    tmp_cls_score = self.Rel_Inf(cls_score[filt],
                                                 rel_scores,
                                                 roidb=None)
                    cls_score[filt] = tmp_cls_score
            if cfg.TEST.TAGGING:
                # Diagnostic: classification accuracy after relation refinement.
                rois_label = torch.from_numpy(
                    rpn_ret['labels_int32'].astype('int64')).to(
                        cls_score.device)
                accuracy_cls = cls_score.max(1)[1].eq(rois_label).float().mean(
                    dim=0)
                print('After refine:', accuracy_cls.item())

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss (skipped entirely in tagging mode)
            if not cfg.MODEL.TAGGING:
                rpn_kwargs.update(
                    dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                         if (k.startswith('rpn_cls_logits')
                             or k.startswith('rpn_bbox_pred'))))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                    **rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # One cls/bbox RPN loss per FPN level.
                    for i, lvl in enumerate(
                            range(cfg.FPN.RPN_MIN_LEVEL,
                                  cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' %
                                              lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                              lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if not cfg.MODEL.RPN_ONLY:
                # bbox loss (Fast R-CNN classification + regression)
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

                # Some loss implementations attach a mean_similarity metric.
                if hasattr(loss_cls, 'mean_similarity'):
                    return_dict['metrics'][
                        'mean_similarity'] = loss_cls.mean_similarity

                if cfg.FAST_RCNN.SAE_REGU:
                    # SAE regularizer over foreground RoIs only; guard against
                    # an empty foreground set.
                    tmp = cls_score.sae_loss[(rpn_ret['labels_int32'] >
                                              0).tolist()]
                    return_dict['losses']['loss_sae'] = 1e-4 * tmp.mean(
                    ) if tmp.numel() > 0 else tmp.new_tensor(0)
                    if torch.isnan(return_dict['losses']['loss_sae']):
                        import pdb
                        pdb.set_trace()

            if cfg.REL_INFER.TRAIN:
                # Relation-inference training loss, weighted by config.
                return_dict['losses']['loss_rel_infer'] = cfg.REL_INFER.TRAIN_WEIGHT * \
                    self.Rel_Inf_Train(rpn_ret['rois'], rpn_ret['labels_int32'], cls_score, rel_scores, roidb=roidb)

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors: promote
            # every scalar loss/metric to a 1-element tensor before return.
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing: expose everything the RPN produced, plus box outputs.
            return_dict.update(rpn_ret)
            if not cfg.MODEL.RPN_ONLY:
                if cfg.TEST.KEEP_HIGHEST:
                    # Suppress all but the arg-max class via a sharp softmax mask.
                    cls_score = F.softmax(cls_score * 1e10, dim=1) * cls_score
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred

        return return_dict
# Example #21 (score: 0)
    def _forward(self,
                 data,
                 im_info,
                 do_vis=False,
                 dataset_name=None,
                 roidb=None,
                 use_gt_labels=False,
                 **rpn_kwargs):
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        if dataset_name is not None:
            dataset_name = blob_utils.deserialize(dataset_name)
        else:
            dataset_name = cfg.TRAIN.DATASETS[
                0] if self.training else cfg.TEST.DATASETS[
                    0]  # assuming only one dataset per run

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        if self.training:
            gt_rois = roidb[0]['boxes'] * im_info[0, 2].data.cpu().numpy()
            gt_classes = roidb[0]['gt_classes']
            sbj_gt_boxes = roidb[0]['sbj_gt_boxes']
            obj_gt_boxes = roidb[0]['obj_gt_boxes']

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if cfg.MODEL.SHARE_RES5 and self.training:
            box_feat, res5_feat = self.Box_Head(blob_conv,
                                                rpn_ret,
                                                use_relu=True)
        else:
            box_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
        cls_score, bbox_pred = self.Box_Outs(box_feat)

        # now go through the predicate branch
        use_relu = False if cfg.MODEL.NO_FC7_RELU else True
        if self.training:
            fg_inds = np.where(rpn_ret['labels_int32'] > 0)[0]
            det_rois = rpn_ret['rois'][fg_inds]
            det_labels = rpn_ret['labels_int32'][fg_inds]
            det_scores = F.softmax(cls_score[fg_inds], dim=1)
            rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info,
                                 dataset_name, roidb)

            select_inds = np.array([])
            repeated_batch_idx = 0 * blob_utils.ones((gt_rois.shape[0], 1))
            select_rois = np.hstack((repeated_batch_idx, gt_rois))
            select_feat = self.detector_feature_map(blob_conv,
                                                    select_rois,
                                                    use_relu=True)
            select_dists, _ = self.Box_Outs(select_feat)
            select_dists = F.softmax(select_dists, -1)
            select_labels = select_dists[:,
                                         1:].max(-1)[1].data.cpu().numpy() + 1
            select_gt_labels = gt_classes

            sbj_feat = self.Box_Head_sg(blob_conv,
                                        rel_ret,
                                        rois_name='sbj_rois',
                                        use_relu=True)
            obj_feat = self.Box_Head_sg(blob_conv,
                                        rel_ret,
                                        rois_name='obj_rois',
                                        use_relu=True)

        else:
            if roidb is not None:
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                gt_rois = roidb['boxes'] * im_scale

                sbj_boxes = roidb['sbj_gt_boxes']
                obj_boxes = roidb['obj_gt_boxes']
                sbj_rois = sbj_boxes * im_scale
                obj_rois = obj_boxes * im_scale
                repeated_batch_idx = 0 * blob_utils.ones(
                    (sbj_rois.shape[0], 1))
                sbj_rois = np.hstack((repeated_batch_idx, sbj_rois))
                obj_rois = np.hstack((repeated_batch_idx, obj_rois))

                if gt_rois.size > 0:
                    repeated_batch_idx = 0 * blob_utils.ones(
                        (gt_rois.shape[0], 1))
                    select_rois = np.hstack((repeated_batch_idx, gt_rois))

                    select_feat = self.detector_feature_map(blob_conv,
                                                            select_rois,
                                                            use_relu=True)
                    select_dists, _ = self.Box_Outs(select_feat)
                    select_labels = self.get_nms_preds(select_dists,
                                                       select_rois,
                                                       softmax=False)
                    select_inds = np.arange(0, select_labels.shape[0]).astype(
                        np.int64)

                    rel_ret = self.EdgePN(select_rois, select_labels,
                                          select_dists, im_info, dataset_name,
                                          None)

                    det_feat_sg = self.Box_Head_sg(blob_conv,
                                                   rel_ret,
                                                   rois_name='det_rois',
                                                   use_relu=True)

                    det_labels = select_labels.copy()
                    det_scores = select_dists[:, 1:].max(
                        -1)[0].data.cpu().numpy()
                    min_ious = np.minimum(
                        box_utils.bbox_overlaps(
                            select_rois[:, 1:][rel_ret['sbj_inds']],
                            sbj_rois[:, 1:]),
                        box_utils.bbox_overlaps(
                            select_rois[:, 1:][rel_ret['obj_inds']],
                            obj_rois[:, 1:]))
                    match_indices = np.where(min_ious.max(-1) >= 0.5)[0]
                    rel_ret['sbj_inds'], rel_ret['obj_inds'], rel_ret['sbj_rois'], rel_ret['obj_rois'],\
                    rel_ret['rel_rois'], rel_ret['sbj_labels'], rel_ret['obj_labels'], rel_ret['sbj_scores'], \
                    rel_ret['obj_scores'] = rel_ret['sbj_inds'][match_indices], \
                    rel_ret['obj_inds'][match_indices], rel_ret['sbj_rois'][match_indices], \
                    rel_ret['obj_rois'][match_indices], rel_ret['rel_rois'][match_indices], \
                    rel_ret['sbj_labels'][match_indices], rel_ret['obj_labels'][match_indices], \
                    rel_ret['sbj_scores'][match_indices], rel_ret['obj_scores'][match_indices]

                    sbj_feat = det_feat_sg[rel_ret['sbj_inds']]
                    obj_feat = det_feat_sg[rel_ret['obj_inds']]

                else:
                    score_thresh = cfg.TEST.SCORE_THRESH
                    while score_thresh >= -1e-06:  # a negative value very close to 0.0
                        det_rois, det_labels, det_scores = \
                            self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                        rel_ret = self.RelPN(det_rois, det_labels, det_scores,
                                             im_info, dataset_name, None)
                        valid_len = len(rel_ret['rel_rois'])
                        if valid_len > 0:
                            break
                        logger.info(
                            'Got {} rel_rois when score_thresh={}, changing to {}'
                            .format(valid_len, score_thresh,
                                    score_thresh - 0.01))
                        score_thresh -= 0.01
                    det_feat = None
                    # #
                    vaild_inds = np.unique(
                        np.concatenate(
                            (rel_ret['sbj_inds'], rel_ret['obj_inds']), 0))
                    vaild_sort_inds = vaild_inds[np.argsort(
                        -det_scores[vaild_inds])]

                    select_inds = vaild_sort_inds[:10]
                    select_rois = det_rois[select_inds]

                    det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=True)
                    det_dists, _ = self.Box_Outs(det_feat)
                    select_dists = det_dists[select_inds]
                    select_labels = det_labels[select_inds].copy()
            else:
                score_thresh = cfg.TEST.SCORE_THRESH
                while score_thresh >= -1e-06:  # a negative value very close to 0.0
                    det_rois, det_labels, det_scores = \
                        self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                    rel_ret = self.RelPN(det_rois, det_labels, det_scores,
                                         im_info, dataset_name, roidb)
                    valid_len = len(rel_ret['rel_rois'])
                    if valid_len > 0:
                        break
                    logger.info(
                        'Got {} rel_rois when score_thresh={}, changing to {}'.
                        format(valid_len, score_thresh, score_thresh - 0.01))
                    score_thresh -= 0.01
                det_feat = None
                vaild_inds = np.unique(
                    np.concatenate((rel_ret['sbj_inds'], rel_ret['obj_inds']),
                                   0))

                vaild_sort_inds = vaild_inds[np.argsort(
                    -det_scores[vaild_inds])]

                select_inds = vaild_sort_inds
                select_rois = det_rois[select_inds]

                det_feat_sg = self.Box_Head_sg(blob_conv,
                                               rel_ret,
                                               rois_name='det_rois',
                                               use_relu=True)
                sbj_feat = det_feat_sg[rel_ret['sbj_inds']]
                obj_feat = det_feat_sg[rel_ret['obj_inds']]

                if det_feat is None:
                    det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=True)
                det_dists, _ = self.Box_Outs(det_feat)
                select_dists = det_dists[select_inds]
                select_labels = det_labels[select_inds].copy()

        if select_inds.size > 2 or self.training:
            # if False:
            entity_fmap = self.obj_feature_map(blob_conv.detach(),
                                               select_rois,
                                               use_relu=True)
            entity_feat0 = self.merge_obj_feats(entity_fmap, select_rois,
                                                select_dists.detach(), im_info)
            edge_ret = self.EdgePN(select_rois, select_labels, select_dists,
                                   im_info, dataset_name, None)
            edge_feat = self.get_phr_feats(
                self.visual_rep(blob_conv,
                                edge_ret,
                                device_id,
                                use_relu=use_relu))
            edge_inds = np.stack((edge_ret['sbj_rois'][:, 0].astype(edge_ret['sbj_inds'].dtype), \
                                      edge_ret['sbj_inds'], edge_ret['obj_inds']), -1)

            im_inds = select_rois[:, 0].astype(edge_inds.dtype)
            entity_feat = self.obj_mps1(entity_feat0, edge_feat, im_inds,
                                        edge_inds)
            entity_feat = self.obj_mps2(entity_feat, edge_feat, im_inds,
                                        edge_inds)

            entity_cls_score = self.ObjClassifier(entity_feat)

            if not self.training:
                select_labels_pred = self.get_nms_preds(
                    entity_cls_score, select_rois)

                det_labels[select_inds] = select_labels_pred
                if use_gt_labels:
                    det_labels[select_inds] = roidb['gt_classes']
                select_twod_inds = np.arange(0, select_labels_pred.shape[
                    0]) * cfg.MODEL.NUM_CLASSES + select_labels_pred
                select_scores = F.softmax(
                    entity_cls_score,
                    -1).view(-1)[select_twod_inds].data.cpu().numpy()

                det_scores[select_inds] = select_scores
                if use_gt_labels:
                    det_scores[select_inds] = np.ones_like(select_scores)

        rel_feat = self.visual_rep(blob_conv,
                                   rel_ret,
                                   device_id,
                                   use_relu=use_relu)

        if not self.training:
            sbj_labels = det_labels[rel_ret['sbj_inds']]
            obj_labels = det_labels[rel_ret['obj_inds']]
            rel_ret['sbj_labels'] = det_labels[rel_ret['sbj_inds']]
            rel_ret['obj_labels'] = det_labels[rel_ret['obj_inds']]
            rel_ret['sbj_scores'] = det_scores[rel_ret['sbj_inds']]
            rel_ret['obj_scores'] = det_scores[rel_ret['obj_inds']]
        else:
            sbj_labels = rel_ret['all_sbj_labels_int32'] + 1
            obj_labels = rel_ret['all_obj_labels_int32'] + 1

        sbj_embed = self.ori_embed[sbj_labels].clone().cuda(device_id)
        obj_embed = self.ori_embed[obj_labels].clone().cuda(device_id)
        sbj_pos = torch.from_numpy(
            self.get_obj_pos(rel_ret['sbj_rois'],
                             im_info)).float().cuda(device_id)
        obj_pos = torch.from_numpy(
            self.get_obj_pos(rel_ret['obj_rois'],
                             im_info)).float().cuda(device_id)

        prod = self.sbj_map(torch.cat(
            (sbj_feat, sbj_embed, sbj_pos), -1)) * self.obj_map(
                torch.cat((obj_feat, obj_embed, obj_pos), -1))

        prd_scores = self.rel_compress(rel_feat * prod)

        if cfg.MODEL.USE_FREQ_BIAS:

            sbj_labels = torch.from_numpy(sbj_labels).long().cuda(device_id)
            obj_labels = torch.from_numpy(obj_labels).long().cuda(device_id)

            prd_bias_scores = self.freq_bias.rel_index_with_labels(
                torch.stack((sbj_labels - 1, obj_labels - 1), 1))

            prd_scores += prd_bias_scores

        if not self.training:
            prd_scores = F.softmax(prd_scores, -1)

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}

            imp_gamma = get_importance_factor(select_rois, sbj_gt_boxes,
                                              obj_gt_boxes, im_info)
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            loss_cls_prd, accuracy_cls_prd = reldn_heads.reldn_losses(
                prd_scores, rel_ret['all_prd_labels_int32'])
            return_dict['losses']['loss_cls_prd'] = loss_cls_prd
            return_dict['metrics']['accuracy_cls_prd'] = accuracy_cls_prd

            loss_cls_entity, accuracy_cls_entity = refine_obj_feats.entity_losses_imp(
                entity_cls_score, select_gt_labels, imp_gamma)
            return_dict['losses']['loss_cls_entity'] = loss_cls_entity
            return_dict['metrics']['accuracy_cls_entity'] = accuracy_cls_entity

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)
        else:
            # Testing
            return_dict['sbj_rois'] = rel_ret['sbj_rois']
            return_dict['obj_rois'] = rel_ret['obj_rois']
            return_dict['sbj_labels'] = rel_ret['sbj_labels']
            return_dict['obj_labels'] = rel_ret['obj_labels']
            return_dict['sbj_scores'] = rel_ret['sbj_scores']
            return_dict['obj_scores'] = rel_ret['obj_scores']
            return_dict['prd_scores'] = prd_scores

            if do_vis:
                return_dict['blob_conv'] = blob_conv

        return return_dict
# ---- scraped-snippet separator (original marker: "예제 #22" / "0") ----
    def forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Run the detector forward and (in training) compute all losses.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image [[im_height, im_width, im_scale], ...].
            roidb: serialized roidb entries, one blob per image (training only).
            **rpn_kwargs: extra RPN blobs forwarded to the RPN loss.

        Returns:
            dict with 'rois', 'cls_score', 'bbox_pred' and, when training,
            the individual loss terms ('loss_rpn_cls', 'loss_rcnn_cls', ...).
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore the entries.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        rois = Variable(torch.from_numpy(rpn_ret['rois'])).cuda(device_id)
        return_dict['rois'] = rois
        if self.training:
            return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5 features are reused by the mask/keypoint heads below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            # rpn loss
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            return_dict['loss_rpn_cls'] = loss_rpn_cls
            return_dict['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            # FIX: fast_rcnn_losses returns three values (loss_cls, loss_bbox,
            # accuracy_cls) — see the other call sites in this file; the previous
            # two-value unpack raised a ValueError at train time.
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['loss_rcnn_cls'] = loss_cls
            return_dict['loss_rcnn_bbox'] = loss_bbox
            # Expose classification accuracy as a monitoring metric
            # (non-'loss_*' key, so generic loss-summing loops ignore it).
            return_dict['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['loss_rcnn_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['loss_rcnn_keypoints'] = loss_keypoints

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for the 3D car pose / translation network.

        Runs the backbone, RPN and box head, then (in training) the optional
        car-classification/rotation head, translation head, 3D->2D projection
        loss, mask head and keypoint head, each gated by its own cfg flag.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image [[im_height, im_width, im_scale], ...].
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: extra RPN blobs forwarded to the RPN loss.

        Returns:
            Training: dict with 'losses' and 'metrics' sub-dicts (each value
            unsqueezed to 1-dim for multi-GPU gathering).
            Testing: dict with 'rois', 'cls_score', 'bbox_pred' (and
            'blob_conv' / 'f_div_C' depending on cfg).
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); restore entries.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables
        if cfg.FPN.NON_LOCAL:
            # Non-local backbone also returns the attention map f_div_C.
            blob_conv, f_div_C = self.Conv_Body(im_data)
            if cfg.MODEL.NON_LOCAL_TEST:
                return_dict['f_div_C'] = f_div_C
        else:
            blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # res5 features are reused by the mask/keypoint heads below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One RPN loss term per FPN pyramid level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # we only use the car cls
            # Hard-coded class index of "car" in the detection label space.
            car_cls_int = 4
            if cfg.MODEL.CAR_CLS_HEAD_ON:
                if getattr(self.car_cls_Head, 'SHARE_RES5', False):
                    # TODO: add the shared res5 module
                    pass
                else:
                    car_cls_rot_feat = self.car_cls_Head(blob_conv, rpn_ret)
                    car_cls_score, car_cls, rot_pred = self.car_cls_Outs(
                        car_cls_rot_feat)
                    # car classification loss, we only fine tune the labelled cars

                # we only use the car cls
                # Indices of RoIs whose sampled label is the car class.
                car_idx = np.where(rpn_ret['labels_int32'] == car_cls_int)
                # NOTE: the config key below is spelled 'WIGHT' in the project
                # config; it must stay that way to match the cfg definition.
                if len(cfg.TRAIN.CE_CAR_CLS_FINETUNE_WIGHT):
                    ce_weight = np.array(cfg.TRAIN.CE_CAR_CLS_FINETUNE_WIGHT)
                else:
                    ce_weight = None

                loss_car_cls, loss_rot, accuracy_car_cls = car_3d_pose_heads.fast_rcnn_car_cls_rot_losses(
                    car_cls_score[car_idx],
                    rot_pred[car_idx],
                    car_cls[car_idx],
                    rpn_ret['car_cls_labels_int32'][car_idx],
                    rpn_ret['quaternions'][car_idx],
                    ce_weight,
                    shape_sim_mat=self.shape_sim_mat)

                return_dict['losses']['loss_car_cls'] = loss_car_cls
                return_dict['losses']['loss_rot'] = loss_rot
                return_dict['metrics']['accuracy_car_cls'] = accuracy_car_cls
                return_dict['metrics']['shape_sim'] = shape_sim(
                    car_cls[car_idx].data.cpu().numpy(), self.shape_sim_mat,
                    rpn_ret['car_cls_labels_int32'][car_idx].astype('int64'))
                return_dict['metrics']['rot_diff_degree'] = rot_sim(
                    rot_pred[car_idx].data.cpu().numpy(),
                    rpn_ret['quaternions'][car_idx])

            if cfg.MODEL.TRANS_HEAD_ON:
                # Decode predicted boxes back to image coordinates.
                pred_boxes = car_3d_pose_heads.bbox_transform_pytorch(
                    rpn_ret['rois'], bbox_pred, im_info,
                    cfg.MODEL.BBOX_REG_WEIGHTS)
                car_idx = np.where(rpn_ret['labels_int32'] == car_cls_int)

                # Build translation head here from the bounding box
                if cfg.TRANS_HEAD.INPUT_CONV_BODY:
                    # Slice out the 4 box coords belonging to the car class.
                    pred_boxes_car = pred_boxes[:, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(
                        blob_conv, rpn_ret, pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(car_trans_feat)
                    car_trans_pred = car_trans_pred[car_idx]
                elif cfg.TRANS_HEAD.INPUT_TRIPLE_HEAD:
                    # NOTE(review): car_cls_rot_feat is only defined when
                    # CAR_CLS_HEAD_ON is set — this branch presumably requires
                    # both flags; confirm against the configs used.
                    pred_boxes_car = pred_boxes[:, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(
                        car_trans_feat, car_cls_rot_feat)
                    car_trans_pred = car_trans_pred[car_idx]
                else:
                    pred_boxes_car = pred_boxes[car_idx, 4 * car_cls_int:4 *
                                                (car_cls_int + 1)].squeeze(
                                                    dim=0)
                    car_trans_feat = self.car_trans_Head(pred_boxes_car)
                    car_trans_pred = self.car_trans_Outs(car_trans_feat)

                label_trans = rpn_ret['car_trans'][car_idx]
                loss_trans = car_3d_pose_heads.car_trans_losses(
                    car_trans_pred, label_trans)
                return_dict['losses']['loss_trans'] = loss_trans
                return_dict['metrics']['trans_diff_meter'], return_dict['metrics']['trans_thresh_per'] = \
                    trans_sim(car_trans_pred.data.cpu().numpy(), rpn_ret['car_trans'][car_idx],
                              cfg.TRANS_HEAD.TRANS_MEAN, cfg.TRANS_HEAD.TRANS_STD)

            # A 3D to 2D projection loss
            # NOTE(review): this branch reads car_idx, car_trans_pred and
            # label_trans, which are only bound inside the TRANS_HEAD_ON /
            # CAR_CLS_HEAD_ON branches above — LOSS_3D_2D_ON appears to
            # require those flags too; confirm before enabling it alone.
            if cfg.MODEL.LOSS_3D_2D_ON:
                # During the mesh generation, using GT(True) or predicted(False) Car ID
                if cfg.LOSS_3D_2D.MESH_GEN_USING_GT:
                    # Acquire car id
                    car_ids = rpn_ret['car_cls_labels_int32'][car_idx].astype(
                        'int64')
                else:
                    # Using the predicted car id
                    print("Not properly implemented for pytorch")
                    car_ids = car_cls_score[car_idx].max(dim=1)
                # Get mesh vertices and generate loss
                UV_projection_loss = plane_projection_loss(
                    car_trans_pred, label_trans, rot_pred[car_idx],
                    rpn_ret['quaternions'][car_idx], car_ids, im_info,
                    self.car_models, self.intrinsic_mat, self.car_names)

                return_dict['losses'][
                    'UV_projection_loss'] = UV_projection_loss

            if cfg.MODEL.MASK_TRAIN_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                # Some metrics (e.g. shape_sim) are plain numpy floats and
                # cannot be unsqueezed.
                if type(v) == np.float64:
                    return_dict['metrics'][k] = v
                else:
                    return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
# ---- scraped-snippet separator (original marker: "예제 #24" / "0") ----
    def _forward(self, data, im_info, do_vis=False, dataset_name=None, roidb=None, use_gt_labels=False, **rpn_kwargs):
        """Detect objects and prepare detection blobs for the relation branch.

        Runs backbone + RPN + box head, then converts raw proposals into
        final detections (det_rois/det_labels/det_scores/det_dists), lowering
        the score threshold until at least one detection survives. In test
        mode with a supplied roidb, ground-truth boxes are used instead.

        Args:
            data: input image batch tensor (NCHW).
            im_info: per-image [[im_height, im_width, im_scale], ...].
            do_vis: unused here; kept for interface compatibility.
            dataset_name: serialized dataset name; defaults to the first
                TRAIN/TEST dataset from cfg.
            roidb: serialized roidb (training), or ground-truth annotations
                for gt-box evaluation modes (testing).
            use_gt_labels: in gt-box test mode, also use ground-truth labels.

        Returns:
            dict with det_rois, det_dists, det_scores, det_labels,
            det_boxes_all, blob_conv and (training / use_gt_labels)
            det_labels_gt.
        """
        im_data = data
        if self.training:
            # if not isinstance(roidb[0], np.array):
            #     roidb = roidb[0]
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb)) # only support one gpu
        if dataset_name is not None:
            dataset_name = blob_utils.deserialize(dataset_name)
        else:
            dataset_name = cfg.TRAIN.DATASETS[0] if self.training else cfg.TEST.DATASETS[0]  # assuming only one dataset per run

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)
        # if not cfg.MODEL.USE_REL_PYRAMID:
        #     blob_conv_prd = self.Prd_RCNN.Conv_Body(im_data)

        if self.training:
            # Collect all ground-truth boxes (scaled to input resolution) and
            # classes across the batch; column 0 of gt_rois is the batch index.
            gt_rois = np.empty((0, 5), dtype=np.float32)
            gt_classes = np.empty((0), dtype=np.int64)
            for i, r in enumerate(roidb):
                rois_i = r['boxes'] * im_info[i, 2]
                rois_i = np.hstack((i * blob_utils.ones((rois_i.shape[0], 1)), rois_i))
                gt_rois = np.append(gt_rois, rois_i, axis=0)
                gt_classes = np.append(gt_classes, r['gt_classes'], axis=0)

        if self.training or roidb is None:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)




        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]
            # if not cfg.MODEL.USE_REL_PYRAMID:
            #     blob_conv_prd = blob_conv_prd[-self.num_roi_levels:]
            # else:
            #     blob_conv_prd = self.RelPyramid(blob_conv)

        if self.training or roidb is None:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
            cls_score, bbox_pred = self.Box_Outs(box_feat)

        
        # now go through the predicate branch
        use_relu = False if cfg.MODEL.NO_FC7_RELU else True
        if self.training:
            score_thresh = cfg.TEST.SCORE_THRESH
            cls_score = F.softmax(cls_score, -1)
            while score_thresh >= -1e-06:  # a negative value very close to 0.0
                det_rois, det_labels, det_scores, det_dists, det_boxes_all = \
                    self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                # Drop degenerate (zero-area) boxes; det_rois columns are
                # [batch_idx, x1, y1, x2, y2].
                real_area = (det_rois[:, 3] - det_rois[:, 1]) * (det_rois[:, 4] - det_rois[:, 2])
                non_zero_area_inds = np.where(real_area > 0)[0]
                det_rois = det_rois[non_zero_area_inds]
                det_labels = det_labels[non_zero_area_inds]
                det_scores = det_scores[non_zero_area_inds]
                det_dists = det_dists[non_zero_area_inds]
                det_boxes_all = det_boxes_all[non_zero_area_inds]
                # rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info, dataset_name, roidb)
                valid_len = len(det_rois)
                if valid_len > 0:
                    break
                logger.info('Got {} det_rois when score_thresh={}, changing to {}'.format(
                    valid_len, score_thresh, score_thresh - 0.01))
                score_thresh -= 0.01
            # NOTE(review): this empty-list assignment is dead — it is
            # overwritten unconditionally two lines below.
            det_labels_gt = []
            # Assign each detection the class of its best-overlapping GT box
            # from the same batch image; background (0) below FG_THRESH IoU.
            ious = box_utils.bbox_overlaps(det_rois[:, 1:], gt_rois[:, 1:]) * \
                                          (det_rois[:, 0][:,None] == gt_rois[:, 0][None, :])
            det_labels_gt = gt_classes[ious.argmax(-1)]
            det_labels_gt[ious.max(-1) < cfg.TRAIN.FG_THRESH] = 0

        else:
            if roidb is not None:
                # raise FError('not support this mode!')
                # assert len(roidb) == 1
                # Ground-truth-box evaluation mode (e.g. sgcls/predcls).
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                
                # NOTE(review): gt_rois is only computed in the training branch
                # above, and det_rois/det_scores are never bound in this branch
                # before being read below — this mode looks like it raises
                # NameError as written; confirm how callers reach it.
                fpn_ret = {'gt_rois': gt_rois}
                if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
                    lvl_min = cfg.FPN.ROI_MIN_LEVEL
                    lvl_max = cfg.FPN.ROI_MAX_LEVEL
                    rois_blob_names = ['gt_rois']
                    for rois_blob_name in rois_blob_names:
                        # Add per FPN level roi blobs named like: <rois_blob_name>_fpn<lvl>
                        target_lvls = fpn_utils.map_rois_to_fpn_levels(
                            fpn_ret[rois_blob_name][:, 1:5], lvl_min, lvl_max)
                        fpn_utils.add_multilevel_roi_blobs(
                            fpn_ret, rois_blob_name, fpn_ret[rois_blob_name], target_lvls,
                            lvl_min, lvl_max)
                det_feats = self.Box_Head(blob_conv, fpn_ret, rois_name='det_rois', use_relu=True)
                det_dists, _ = self.Box_Outs(det_feats)
                det_boxes_all = None
                if use_gt_labels:
                    det_labels_gt = gt_classes
                    det_labels = gt_classes
            else:

                # Standard test mode: lower the score threshold until at least
                # one non-degenerate detection survives.
                score_thresh = cfg.TEST.SCORE_THRESH
                while score_thresh >= -1e-06:  # a negative value very close to 0.0
                    det_rois, det_labels, det_scores, det_dists, det_boxes_all = \
                        self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                    real_area = (det_rois[:, 3] - det_rois[:, 1]) * (det_rois[:, 4] - det_rois[:, 2])
                    non_zero_area_inds = np.where(real_area > 0)[0]
                    det_rois = det_rois[non_zero_area_inds]
                    det_labels = det_labels[non_zero_area_inds]
                    det_scores = det_scores[non_zero_area_inds]
                    det_dists = det_dists[non_zero_area_inds]
                    det_boxes_all = det_boxes_all[non_zero_area_inds]
                    # rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info, dataset_name, roidb)
                    valid_len = len(det_rois)
                    if valid_len > 0:
                        break
                    logger.info('Got {} det_rois when score_thresh={}, changing to {}'.format(
                        valid_len, score_thresh, score_thresh - 0.01))
                    score_thresh -= 0.01 


        return_dict['det_rois'] = det_rois
        num_rois = det_rois.shape[0]
        if not isinstance(det_dists, torch.Tensor):
            assert det_dists.shape[0] == num_rois
            det_dists = torch.from_numpy(det_dists).float().cuda(device_id)
        
        return_dict['det_dists'] = det_dists
        return_dict['det_scores'] = det_scores
        return_dict['blob_conv'] = blob_conv
        return_dict['det_boxes_all'] = det_boxes_all
        # NOTE(review): det_boxes_all is None in the gt-box branch above, so
        # this assert would fail there — verify that branch is unreachable or
        # guarded by callers.
        assert det_boxes_all.shape[0] == num_rois
        return_dict['det_labels'] = det_labels
        # return_dict['blob_conv_prd'] = blob_conv_prd

        if self.training or use_gt_labels:
            return_dict['det_labels_gt'] = det_labels_gt

        return return_dict
# ---- scraped-snippet separator (original marker: "예제 #25" / "0") ----
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Run the 3DCE-style multi-slice detection forward pass.

        Args:
            data: image batch of shape [n, M*3, h, w], where
                M = cfg.LESION.NUM_IMAGES_3DCE; reshaped below to
                [n*M, 3, h, w] so the 2-D conv body sees each slice group
                as an independent 3-channel image.
            im_info: per-image info rows passed to the RPN.
            roidb: serialized roidb entries (training only); deserialized
                in place before use.
            **rpn_kwargs: extra RPN blobs consumed by generic_rpn_losses.

        Returns:
            dict: 'cls_score'/'bbox_pred' (unless cfg.MODEL.RPN_ONLY),
            plus 'losses'/'metrics' sub-dicts in training, or
            'rois'/'blob_conv' at test time.
        """
        M = cfg.LESION.NUM_IMAGES_3DCE
        # data.shape: [n,M*3,h,w]
        n, c, h, w = data.shape
        # Fold the M context images into the batch dimension.
        im_data = data.view(n * M, 3, h, w)
        if self.training:
            # roidb arrives serialized (one blob per image); unpack it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id is never used below.
        device_id = im_data.get_device()
        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        # blob.shape [n, c,h,w] for 2d
        # blob.shape [nM,c,h,w] for 3DCE

        blob_conv_for_RPN = []
        blob_conv_for_RCNN = []
        # 12/25,concat all slices before RPN.
        if cfg.LESION.CONCAT_BEFORE_RPN:
            # Channel-concatenate all M slice features; RPN and RCNN share them.
            for blob in blob_conv:
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                blob_conv_for_RPN.append(blob)
            blob_conv_for_RCNN = blob_conv_for_RPN
        # 01/20,ele-sum all slices before RPN.
        elif cfg.LESION.SUM_BEFORE_RPN:
            # Element-wise sum across the M slices for RPN; RCNN still gets
            # the channel-concatenated features.
            for blob in blob_conv:
                blob_ = 0
                for i in range(M):
                    blob_ += blob[n * i:n * (i + 1), :, :, :]
                blob_conv_for_RPN.append(blob_)
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                blob_conv_for_RCNN.append(blob)
        # Only support three_slices each modality currently.
        elif cfg.LESION.MULTI_MODALITY:
            # Split the concatenated channels into the two modalities and
            # fuse them with gif_net; RPN and RCNN share the fused features.
            for blob in blob_conv:
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                m1_blob_conv = blob[:, 0:c, :, :]
                m2_blob_conv = blob[:, c:, :, :]
                blob_conv_for_RPN.append(
                    self.gif_net(m1_blob_conv, m2_blob_conv))
            blob_conv_for_RCNN = blob_conv_for_RPN
        # Standard 3DCE, feed middle slice into RPN.
        else:
            for blob in blob_conv:
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                # Middle slice's channel block only for the RPN.
                blob_conv_for_RPN.append(blob[:, (M // 2) * c:(M // 2 + 1) *
                                              c, :, :])
                blob_conv_for_RCNN.append(blob)

        rpn_ret = self.RPN(blob_conv_for_RPN, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv_for_RCNN = blob_conv_for_RCNN[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv_for_RCNN

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: Box_Head also returns the res5 feature map.
                box_feat, res5_feat = self.Box_Head(blob_conv_for_RCNN,
                                                    rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv_for_RCNN, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            # print cls_score.shape
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if cfg.MODEL.FASTER_RCNN:
                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        return return_dict
예제 #26
0
    def forward(self, data, im_info=None, roidb=None, only_body=None, **rpn_kwargs):
        """Generalized R-CNN forward with optional RPN, mask and keypoint heads.

        Args:
            data: input image batch.
            im_info: per-image info rows for the RPN (when cfg.RPN.RPN_ON).
            roidb: serialized roidb entries (training with RPN only).
            only_body: when not None, return after the conv body / RPN stage.
            **rpn_kwargs: precomputed proposal blobs used in place of the RPN
                when cfg.RPN.RPN_ON is False; otherwise extra args for the
                RPN losses.

        Returns:
            dict: 'losses'/'metrics' sub-dicts when training, or
            'rois'/'cls_score'/'bbox_pred'/'blob_conv' when testing.
        """
        im_data = data
        if self.training and cfg.RPN.RPN_ON:
            # roidb arrives serialized (one blob per image); unpack it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        # NOTE(review): device_id is never used below.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        if cfg.RPN.RPN_ON:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)
        else:
            # Proposals supplied externally through rpn_kwargs.
            rpn_ret = rpn_kwargs

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        # Early exit for feature extraction. NOTE(review): in training mode
        # return_dict is empty here, since 'blob_conv' is only set above when
        # not training — confirm this is the intended behavior.
        if only_body is not None:
            return return_dict

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: Box_Head also returns the res5 feature map,
                # reused by the mask/keypoint heads below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            if cfg.RPN.RPN_ON:
                rpn_kwargs.update(dict(
                    (k, rpn_ret[k]) for k in rpn_ret.keys()
                    if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
                ))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # One loss entry per FPN level.
                    for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            # NOTE(review): unlike the sibling _forward variants this loss is
            # not guarded by cfg.MODEL.FASTER_RCNN; with cfg.MODEL.RPN_ONLY
            # set, cls_score/bbox_pred would be unbound here — verify the
            # supported config combinations.
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(res5_feat, rpn_ret,
                                               roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                if cfg.MRCNN.FUSION:
                    mask_pred = mask_feat
                elif cfg.MODEL.BOUNDARY_ON:
                    # Mask head returns (mask, boundary) predictions.
                    mask_pred = mask_feat[0]
                    boundary_pred = mask_feat[1]
                else:
                    mask_pred = mask_feat
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask
                if cfg.MODEL.BOUNDARY_ON:
                    loss_boundary = mask_rcnn_heads.mask_rcnn_losses_balanced(boundary_pred, rpn_ret['boundary_int32'])
                    return_dict['losses']['loss_boundary'] = loss_boundary

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (Neither in Detectron)
                    # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                    kps_feat = self.Keypoint_Head(res5_feat, rpn_ret,
                                                  roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # return_dict['keypoints_pred'] = kps_pred
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'], rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints
        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
예제 #27
0
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for the 3-D lesion detector variant.

        Unlike the 3DCE variant, the conv body here consumes the volumetric
        batch directly and returns two feature pyramids: one for the RPN and
        one for the RoI heads.

        Args:
            data: input batch of shape [n, c, cfg.LESION.NUM_SLICE, h, w].
            im_info: per-image info rows passed to the RPN.
            roidb: serialized roidb entries (training only); deserialized
                in place before use.
            **rpn_kwargs: extra RPN blobs consumed by generic_rpn_losses.

        Returns:
            dict: 'cls_score'/'bbox_pred' (unless cfg.MODEL.RPN_ONLY), plus
            'losses'/'metrics' sub-dicts in training, or 'rois' and
            'blob_conv' at test time.
        """
        im_data = data
        if self.training:
            # roidb arrives serialized (one blob per image); unpack it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables

        # blob_conv_for_RPN[i]: [n, c, h, w] features for the RPN head.
        # blob_conv_for_RCNN[i]: [n, c, h, w] features for the RoI heads.
        blob_conv_for_RPN, blob_conv_for_RCNN = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv_for_RPN, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads; the conv
            # body may produce extra levels used for RPN proposals only.
            blob_conv_for_RCNN = blob_conv_for_RCNN[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv_for_RCNN

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: Box_Head also returns the res5 feature map.
                box_feat, res5_feat = self.Box_Head(blob_conv_for_RCNN, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv_for_RCNN, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # RPN loss: hand the RPN's raw logits/deltas to the generic loss.
            rpn_kwargs.update(dict(
                (k, rpn_ret[k]) for k in rpn_ret.keys()
                if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred'))
            ))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN level.
                for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if cfg.MODEL.FASTER_RCNN:
                # Box classification / regression loss on the RPN proposals.
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'],
                    rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)
        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        return return_dict
def main():
    """Entry point for step-based training.

    Parses command-line args, adapts the config to the available GPUs,
    builds the dataset/dataloader, constructs the model and optimizer,
    optionally restores a checkpoint, and runs the training loop with
    LR warm-up/decay and periodic checkpointing.
    """

    args = parse_args()
    print('Called with args:')
    print(args)

    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    log_path = os.path.join(output_dir, 'train.log')
    logger = setup_logging_to_file(log_path)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    # NOTE(review): given the availability check above, cfg.CUDA ends up True
    # whenever args.cuda or cfg.NUM_GPUS > 0 — confirm the else branch is
    # reachable only via misconfiguration.
    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    # Select training dataset and class count before loading the cfg file,
    # so the cfg file can still override these values.
    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train', )
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "common":
        cfg.TRAIN.DATASETS = ('common_train', )
        cfg.MODEL.NUM_CLASSES = 14
    elif args.dataset == "kitti":
        cfg.TRAIN.DATASETS = ('kitti_train', )
        cfg.MODEL.NUM_CLASSES = 8
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' %
          (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' %
          (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' %
          (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    # Fewer effective samples per step means proportionally more steps.
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(
        map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print(
        'Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
        '    SOLVER.STEPS: {} --> {}\n'
        '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps,
                                                cfg.SOLVER.STEPS, old_max_iter,
                                                cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print(
            'Scale FPN rpn_proposals collect size directly propotional to the change of IMS_PER_BATCH:\n'
            '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(
                cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    # pdb.set_trace()
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb',
                timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = BatchSampler(sampler=MinibatchSampler(
        ratio_list, ratio_index),
                                batch_size=args.batch_size,
                                drop_last=True)
    dataset = RoiDataLoader(roidb, cfg.MODEL.NUM_CLASSES, training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    # Three parameter groups: GN params, biases, and everything else —
    # each with its own lr / weight-decay treatment.
    gn_params = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    for key, value in dict(maskRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'gn' in key:
                gn_params.append(value)
            elif 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [{
        'params': nonbias_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY
    }, {
        'params':
        bias_params,
        'lr':
        0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
        'weight_decay':
        cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0
    }, {
        'params': gn_params,
        'lr': 0,
        'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN
    }]
    # Names of parameters for each parameter group.
    # NOTE(review): only 2 of the 3 groups are listed (gn params missing);
    # harmless while ensure_optimizer_ckpt_params_order below stays
    # commented out, but fix before re-enabling it.
    param_names = [nonbias_param_names, bias_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    ### Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name,
                                map_location=lambda storage, loc: storage)
        net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print(
                        'train_size value: %d different from the one in checkpoint: %d'
                        % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            # optimizer.load_state_dict(checkpoint['optimizer'])
            misc_utils.load_optimizer_state_dict(optimizer,
                                                 checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0][
        'lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN,
                                 cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        # Persist config + args so the run can be reproduced later.
        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args, args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        # Pre-bind `step` so the except-handler's save_ckpt works even if
        # the loop body never runs.
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 -
                                                                 alpha) + alpha
                else:
                    raise KeyError(
                        'Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                # Warm-up finished: restore the base learning rate.
                net_utils.update_learning_rate(optimizer, lr,
                                               cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            # Gradient accumulation over iter_size minibatches.
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    # Epoch boundary: restart the loader.
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)
                # pdb.set_trace()

                for key in input_data:
                    if key != 'roidb':  # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                # NOTE(review): this deserialized roidb is never used below
                # (the model receives the serialized one via input_data) —
                # presumably left over from debugging.
                roidb = list(
                    map(lambda x: blob_utils.deserialize(x)[0],
                        input_data['roidb'][0]))
                # pdb.set_trace()

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr, logger)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN,
                          optimizer, logger)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer,
                  logger)

    except (RuntimeError, KeyboardInterrupt):
        # Best-effort checkpoint on crash or Ctrl-C, then surface the trace.
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer,
                  logger)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()
예제 #29
0
    def _forward(self,
                 data,
                 im_info,
                 roidb=None,
                 conv_body_only=False,
                 **rpn_kwargs):
        """Forward pass for the amodal/inmodal mask variant.

        Args:
            data: input image batch.
            im_info: per-image info rows for the RPN.
            roidb: serialized roidb entries (training only).
            conv_body_only: when True, return just the conv-body features.
            **rpn_kwargs: extra RPN blobs consumed by generic_rpn_losses.

        Returns:
            The conv features when conv_body_only, otherwise a dict with
            'losses'/'metrics' sub-dicts (training) or
            'rois'/'cls_score'/'bbox_pred' (testing).
        """
        im_data = data
        # pdb.set_trace()

        if self.training:
            # roidb arrives serialized (one blob per image); unpack it.
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        # pdb.set_trace()

        # NOTE(review): device_id is never used below.
        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)
        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        # if self.training:
        #     # can be used to infer fg/bg ratio
        #     return_dict['rois_label'] = rpn_ret['labels_int32']

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if conv_body_only:
            return blob_conv

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                # Shared res5: Box_Head also returns the res5 feature map,
                # reused by the mask heads below.
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            if not cfg.MODEL.INMODAL_ON:
                cls_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                # Inmodal mode adds a separate amodal classification score.
                cls_score, amodal_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                # One loss entry per FPN level.
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            # NOTE(review): the amodal loss/accuracy terms are computed but
            # deliberately left out of the returned dicts (see the commented
            # lines) — confirm whether that is still intended.
            loss_cls, loss_amodal, loss_bbox, accuracy_cls, accuracy_amodal = fast_rcnn_heads.fast_rcnn_amodal_losses(
                cls_score, amodal_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['labels_amodal_int32'], rpn_ret['bbox_targets'],
                rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            # return_dict['losses']['loss_amodal'] = loss_amodal
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls
            # return_dict['metrics']['accuracy_amodal'] = accuracy_amodal

            if cfg.MODEL.MASK_ON:
                if getattr(self.Amodal_Mask_Head, 'SHARE_RES5', False):
                    amodal_mask_feat = self.Amodal_Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    amodal_mask_feat = self.Amodal_Mask_Head(
                        blob_conv, rpn_ret)
                amodal_mask_pred = self.Amodal_Mask_Outs(amodal_mask_feat)
                # return_dict['mask_pred'] = mask_pred
                # mask loss
                loss_mask_amodal = mask_rcnn_heads.mask_rcnn_losses(
                    amodal_mask_pred, rpn_ret['masks_amodal_int32'])
                return_dict['losses']['loss_mask_amodal'] = loss_mask_amodal

            if cfg.MODEL.INMODAL_ON:
                if getattr(self.Inmodal_Mask_Head, 'SHARE_RES5', False):
                    inmodal_mask_feat = self.Inmodal_Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    inmodal_mask_feat = self.Inmodal_Mask_Head(
                        blob_conv, rpn_ret)
                inmodal_mask_pred = self.Inmodal_Mask_Outs(inmodal_mask_feat)
                # print(amodal_pred.shape)
                # print(rpn_ret['masks_amodal_int32'].shape)
                # input()
                loss_mask_inmodal = mask_rcnn_heads.mask_rcnn_losses(
                    inmodal_mask_pred, rpn_ret['masks_inmodal_int32'])
                # return_dict['losses']['loss_amodal'] = cfg.MODEL.AMODAL_WEIGHT * loss_amodal
                return_dict['losses']['loss_mask_inmodal'] = loss_mask_inmodal

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            if cfg.MODEL.INMODAL_ON:
                return_dict['amodal_score'] = amodal_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
# ----- Example #30 -----
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for 3DCE / MVP-Net lesion detection.

        Args:
            data: input blob of shape (n, M*3, h, w) holding M adjacent
                3-channel slices per sample, M = cfg.LESION.NUM_IMAGES_3DCE.
            im_info: per-image [im_height, im_width, im_scale] rows.
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: RPN target blobs forwarded to the RPN losses.

        Returns:
            dict with predictions ('cls_score', 'bbox_pred', ...); in
            training mode also 'losses' and 'metrics' sub-dicts.
        """
        M = cfg.LESION.NUM_IMAGES_3DCE
        # data.shape: [n, M*3, h, w] -- fold the M slices into the batch
        # dimension so the backbone processes plain 3-channel images.
        n, c, h, w = data.shape
        im_data = data.view(n * M, 3, h, w)
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables
        if cfg.LESION.USE_POSITION:
            blob_conv, res5_feat = self.Conv_Body(im_data)
        else:
            blob_conv = self.Conv_Body(im_data)

        # blob.shape = [n, c, h, w] for 2d
        # blob.shape = [nM, c, h, w] for 3DCE or multi-modal

        blob_conv_for_RPN = []
        blob_conv_for_RCNN = []
        if cfg.LESION.CONCAT_BEFORE_RPN:
            # MVP-Net: concatenate all slices channel-wise before RPN and
            # re-weight them with CBAM attention.
            for blob in blob_conv:
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                blob_cbam = self.cbam(blob)
                blob_conv_for_RPN.append(blob_cbam)
            blob_conv_for_RCNN = blob_conv_for_RPN
        else:
            # Standard 3DCE: only the middle slice's channels feed the RPN;
            # the RCNN head still sees all M slices.
            for blob in blob_conv:
                _, c, h, w = blob.shape
                blob = blob.view(n, M * c, h, w)
                blob_conv_for_RPN.append(blob[:, (M // 2) * c:(M // 2 + 1) *
                                              c, :, :])
                blob_conv_for_RCNN.append(blob)

        rpn_ret = self.RPN(blob_conv_for_RPN, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads;
            # `blob_conv` may include extra levels used only for proposals.
            blob_conv_for_RCNN = blob_conv_for_RCNN[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv_for_RCNN

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv_for_RCNN,
                                                    rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv_for_RCNN, rpn_ret)

            if cfg.LESION.USE_POSITION:
                # NOTE(review): the 256-channel / batch-3 reshapes below are
                # hard-coded for the published configs -- confirm before
                # changing NUM_IMAGES_3DCE to other values.
                if cfg.LESION.NUM_IMAGES_3DCE == 3:
                    position_feat = blob_conv_for_RPN[0]
                    n, c, h, w = position_feat.shape
                    position_feat = position_feat.view(3, 256, h, w)
                elif cfg.LESION.NUM_IMAGES_3DCE == 9:
                    # Keep only the middle 3*256 channels.
                    position_feat = blob_conv_for_RPN[0][:,
                                                         3 * 256:6 * 256, :, :]
                    n, c, h, w = position_feat.shape
                    position_feat = position_feat.view(3, 256, h, w)
                position_feat = self.Position_Head(position_feat)
                pos_cls_pred = self.Position_Cls_Outs(position_feat)
                pos_reg_pred = self.Position_Reg_Outs(position_feat)
                return_dict['pos_cls_pred'] = pos_cls_pred
                return_dict['pos_reg_pred'] = pos_reg_pred
            if cfg.LESION.POS_CONCAT_RCNN:
                # NOTE(review): this branch needs cfg.LESION.USE_POSITION,
                # otherwise `position_feat` is unbound here -- confirm the
                # config combinations actually used.
                try:
                    if self.training:
                        # box_feat: (n*num_nms, 1024); position_feat: (n, 1024)
                        # expand -> (num_nms, n, 1024)
                        position_feat_ = position_feat.expand(
                            cfg.TRAIN.RPN_BATCH_SIZE_PER_IM, -1, -1)
                        # transpose -> (n, num_nms, 1024),
                        # view -> (n*num_nms, 1024)
                        position_feat_concat = torch.transpose(
                            position_feat_, 1, 0).contiguous().view(-1, 1024)
                    else:
                        # position_feat: (1, 1024) -> expand over the kept
                        # proposals, then reshape to match box_feat.
                        position_feat_ = position_feat.expand(
                            cfg.TEST.RPN_PRE_NMS_TOP_N, -1)
                        position_feat_concat = torch.transpose(
                            position_feat_, 1,
                            0).contiguous().view_as(box_feat)
                except RuntimeError:
                    # BUGFIX: was a bare `except:` that dropped into
                    # pdb.set_trace() and printed `position_feat_concat`,
                    # which is unbound on exactly the failures caught here.
                    # Log the shapes and fall back to the un-concatenated
                    # box features (the original fall-through behavior).
                    logger.error(
                        'POS_CONCAT_RCNN shape mismatch: position_feat %s, '
                        'box_feat %s', tuple(position_feat.shape),
                        tuple(box_feat.shape))
                else:
                    box_feat = torch.cat([box_feat, position_feat_concat], 1)
            # BUGFIX: Box_Outs used to be invoked twice (inside the try/else
            # and again unconditionally); one call is sufficient and yields
            # the same result.
            cls_score, bbox_pred = self.Box_Outs(box_feat)
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            if cfg.MODEL.FASTER_RCNN:  # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.LESION.USE_POSITION:
                pos_cls_loss, pos_reg_loss, accuracy_position = position_losses(
                    pos_cls_pred, pos_reg_pred, roidb)
                return_dict['losses']['pos_cls_loss'] = pos_cls_loss
                return_dict['losses']['pos_reg_loss'] = pos_reg_loss
                return_dict['metrics']['accuracy_position'] = accuracy_position

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass for a Mask/Keypoint R-CNN style model.

        Images whose roidb entry has ``has_mask == False`` get their mask
        targets zeroed so no mask loss is back-propagated for their RoIs.

        Args:
            data: image blob, (N, 3, H, W).
            im_info: per-image [im_height, im_width, im_scale] rows.
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: RPN target blobs forwarded to the RPN losses.

        Returns:
            dict with predictions; in training mode also 'losses'/'metrics'.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)  # (2, 1024, 50, 50), conv4

        rpn_ret = self.RPN(blob_conv, im_info, roidb)  # (N*2000, 5)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads.
            # `blob_conv` may include extra blobs that are used for RPN
            # proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(
                    blob_conv, rpn_ret)  # box_feat: (N*512, 2014, 1, 1)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(
                box_feat)  # cls_score: (N*512, C)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                # LJY: if the image has no mask annotations, disable loss
                # computing for its RoIs by zeroing their mask targets.
                # BUGFIX: this loop used to run unconditionally, but
                # 'mask_rois'/'masks_int32' exist in rpn_ret only when the
                # mask branch is enabled (KeyError otherwise).
                for idx, e in enumerate(roidb):
                    has_mask = e['has_mask']
                    if has_mask:
                        continue
                    ind = rpn_ret['mask_rois'][:, 0] == idx
                    rpn_ret['masks_int32'][ind, :] = np.zeros_like(
                        rpn_ret['masks_int32'][ind, :])

                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (neither
                    # in Detectron). Also, RPN would need to generate the
                    # label 'roi_has_keypoints_int32'.
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred

        return return_dict
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Run backbone, RPN and (optionally) the RoI heads.

        Returns a dict carrying predictions ('cls_score', 'bbox_pred',
        'rois'/'blob_conv' at test time) and, in training mode, 'losses'
        and 'metrics' sub-dicts of 1-dim tensors.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        ret = {}  # collected return variables

        conv_feats = self.Conv_Body(im_data)
        # self.RPN is rpn_heads.generic_rpn_outputs; with FPN enabled it
        # instantiates FPN.fpn_rpn_outputs, which collects anchors over all
        # scales and aspect ratios.
        rpn_out = self.RPN(conv_feats, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Drop pyramid levels that exist only for RPN proposals; the RoI
            # heads consume just the last `num_roi_levels` blobs.
            conv_feats = conv_feats[-self.num_roi_levels:]

        if not self.training:
            ret['blob_conv'] = conv_feats

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                roi_feat, res5_feat = self.Box_Head(conv_feats, rpn_out)
            else:
                # e.g. fast_rcnn_heads.roi_2mlp_head for Faster R-CNN FPN R-50
                roi_feat = self.Box_Head(conv_feats, rpn_out)
            cls_score, bbox_pred = self.Box_Outs(roi_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if not self.training:
            # Testing: expose the raw predictions and return early.
            ret['rois'] = rpn_out['rois']
            ret['cls_score'] = cls_score
            ret['bbox_pred'] = bbox_pred
            return ret

        ret['losses'] = {}
        ret['metrics'] = {}

        # RPN losses: forward every rpn_cls_logits*/rpn_bbox_pred* blob.
        rpn_kwargs.update({
            k: rpn_out[k]
            for k in rpn_out.keys()
            if k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')
        })
        loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
            **rpn_kwargs)
        if cfg.FPN.FPN_ON:
            lvls = range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)
            for i, lvl in enumerate(lvls):
                ret['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i]
                ret['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i]
        else:
            ret['losses']['loss_rpn_cls'] = loss_rpn_cls
            ret['losses']['loss_rpn_bbox'] = loss_rpn_bbox

        # Fast R-CNN box classification / regression losses.
        loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
            cls_score, bbox_pred, rpn_out['labels_int32'],
            rpn_out['bbox_targets'], rpn_out['bbox_inside_weights'],
            rpn_out['bbox_outside_weights'])
        ret['losses']['loss_cls'] = loss_cls
        ret['losses']['loss_bbox'] = loss_bbox
        ret['metrics']['accuracy_cls'] = accuracy_cls

        if cfg.MODEL.MASK_ON:
            if getattr(self.Mask_Head, 'SHARE_RES5', False):
                mask_feat = self.Mask_Head(
                    res5_feat,
                    rpn_out,
                    roi_has_mask_int32=rpn_out['roi_has_mask_int32'])
            else:
                mask_feat = self.Mask_Head(conv_feats, rpn_out)
            mask_pred = self.Mask_Outs(mask_feat)
            ret['losses']['loss_mask'] = mask_rcnn_heads.mask_rcnn_losses(
                mask_pred, rpn_out['masks_int32'])

        if cfg.MODEL.KEYPOINTS_ON:
            if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                # No SHARE_RES5 keypoint head exists yet (nor in Detectron);
                # RPN would also need to emit 'roi_has_keypoints_int32'.
                kps_feat = self.Keypoint_Head(
                    res5_feat,
                    rpn_out,
                    roi_has_keypoints_int32=rpn_out['roi_has_keypoint_int32'])
            else:
                kps_feat = self.Keypoint_Head(conv_feats, rpn_out)
            kps_pred = self.Keypoint_Outs(kps_feat)
            if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                loss_kps = keypoint_rcnn_heads.keypoint_losses(
                    kps_pred, rpn_out['keypoint_locations_int32'],
                    rpn_out['keypoint_weights'])
            else:
                loss_kps = keypoint_rcnn_heads.keypoint_losses(
                    kps_pred, rpn_out['keypoint_locations_int32'],
                    rpn_out['keypoint_weights'],
                    rpn_out['keypoint_loss_normalizer'])
            ret['losses']['loss_kps'] = loss_kps

        # Work around the pytorch 0.4 bug when gathering scalar (0-dim)
        # tensors: promote every loss/metric to 1-dim.
        ret['losses'] = {k: v.unsqueeze(0) for k, v in ret['losses'].items()}
        ret['metrics'] = {
            k: v.unsqueeze(0)
            for k, v in ret['metrics'].items()
        }

        return ret
# ----- Example #33 -----
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Forward pass with optional debugging dumps.

        Besides the standard Faster/Mask R-CNN computation, this variant can
        (a) visualize the RPN quantized-anchor regression targets when
        cfg.RPN.VIS_QUANT_TARGET is set (training only), and (b) dump the
        aligned proposals / transformed boxes to JSON at test time when
        cfg.TEST.PROPOSALS_OUT is set.

        Args:
            data: image blob, (N, 3, H, W).
            im_info: per-image [im_height, im_width, im_scale] rows.
            roidb: serialized roidb entries (training only).
            **rpn_kwargs: RPN target blobs forwarded to the RPN losses.

        Returns:
            dict with predictions; in training mode also 'losses'/'metrics'.
        """
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads.
            # `blob_conv` may include extra blobs that are used for RPN
            # proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]

        if not self.training:
            return_dict['blob_conv'] = blob_conv

        if not cfg.MODEL.RPN_ONLY:
            if cfg.MODEL.SHARE_RES5 and self.training:
                box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
            else:
                box_feat = self.Box_Head(blob_conv, rpn_ret)
            cls_score, bbox_pred = self.Box_Outs(box_feat)
        else:
            # TODO: complete the returns for RPN only situation
            pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)

            if cfg.RPN.VIS_QUANT_TARGET:
                # Debug visualization: draw the quantized anchors (red), the
                # boxes obtained by applying the regression targets to them
                # (green), and the ground-truth boxes (white) on the input
                # image, then save the figure.
                import numpy as np
                import json
                import os
                import time

                # Collect the GT boxes for the (assumed single) image.
                gt_boxes = []
                gt_label = roidb[0]['gt_classes']
                for inds, item in enumerate(gt_label):
                    if item != 0:
                        gt_boxes.append(roidb[0]['boxes'][inds])

                gt_boxes = np.array(gt_boxes, dtype=np.float32)
                # BUGFIX: added .cpu() so this also works for CUDA im_info.
                gt_boxes *= im_info.detach().cpu().numpy()[:, 2]

                path = "/nfs/project/libo_i/Boosting/Targets_Info"
                if not os.path.exists(path):
                    os.makedirs(path)
                b, c, h, w = rpn_kwargs['rpn_cls_logits_fpn3'].shape
                sample_targets = rpn_kwargs[
                    'rpn_bbox_targets_wide_fpn3'][:, :, :h, :w]

                line_targets = sample_targets.detach().data.cpu().numpy()
                with open(os.path.join(path, "quant_anchors.json"), "r") as fp:
                    quant_anchors = np.array(json.load(fp), dtype=np.float32)
                    quant_anchors = quant_anchors[:h, :w]

                line_targets = line_targets[:, 4:8, :, :].transpose(
                    (0, 2, 3, 1)).reshape(quant_anchors.shape)
                line_targets = line_targets.reshape(-1, 4)

                width = im_data.shape[3]
                height = im_data.shape[2]
                # Apply the regression targets to the anchors here so the
                # shifted boxes can be drawn alongside them.
                line_quant_anchors = quant_anchors.reshape(-1, 4)
                pred_boxes = box_utils.onedim_bbox_transform(
                    line_quant_anchors, line_targets)
                pred_boxes = box_utils.clip_tiled_boxes(
                    pred_boxes, (height, width, 3))

                # NOTE(review): assumes batch size 1 -- the reshape below
                # folds the whole blob into a single image.
                im = im_data.detach().cpu().numpy().reshape(3, height,
                                                            width).transpose(
                                                                (1, 2, 0))

                # Undo the pixel-mean normalization for display.
                means = np.squeeze(cfg.PIXEL_MEANS)
                for i in range(3):
                    im[:, :, i] += means[i]

                im = im.astype(int)
                dpi = 200
                fig = plt.figure(frameon=False)
                fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.axis('off')
                fig.add_axes(ax)
                ax.imshow(im[:, :, ::-1])
                # Draw the ground-truth boxes on the image.
                for item in gt_boxes:
                    ax.add_patch(
                        plt.Rectangle((item[0], item[1]),
                                      item[2] - item[0],
                                      item[3] - item[1],
                                      fill=False,
                                      edgecolor='white',
                                      linewidth=1,
                                      alpha=1))

                cnt = 0
                for inds, before_item in enumerate(line_quant_anchors):
                    after_item = pred_boxes[inds]
                    targets_i = line_targets[inds]
                    # Skip anchors with all-zero targets (unassigned).
                    if np.sum(targets_i) == 0:
                        continue
                    ax.add_patch(
                        plt.Rectangle((before_item[0], before_item[1]),
                                      before_item[2] - before_item[0],
                                      before_item[3] - before_item[1],
                                      fill=False,
                                      edgecolor='r',
                                      linewidth=1,
                                      alpha=1))

                    ax.add_patch(
                        plt.Rectangle((after_item[0], after_item[1]),
                                      after_item[2] - after_item[0],
                                      after_item[3] - after_item[1],
                                      fill=False,
                                      edgecolor='g',
                                      linewidth=1,
                                      alpha=1))

                    logger.info("valid boxes: {}".format(cnt))
                    cnt += 1

                if cnt != 0:
                    ticks = time.time()
                    fig.savefig(
                        "/nfs/project/libo_i/Boosting/Targets_Info/{}.png".
                        format(ticks),
                        dpi=dpi)

                plt.close('all')

            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            if cfg.RPN.ZEROLOSS:
                # Train the RPN only: replace the box-head losses with
                # differentiable zeros so no gradient reaches the RCNN head.
                zero_loss_bbox = torch.Tensor([0.]).squeeze().cuda()
                zero_loss_bbox.requires_grad = True
                zero_loss_cls = torch.Tensor([0.]).squeeze().cuda()
                zero_loss_cls.requires_grad = True
                return_dict['losses']['loss_bbox'] = zero_loss_bbox
                return_dict['losses']['loss_cls'] = zero_loss_cls

            else:
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['losses']['loss_cls'] = loss_cls

            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.MASK_ON:
                if getattr(self.Mask_Head, 'SHARE_RES5', False):
                    mask_feat = self.Mask_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                else:
                    mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                mask_pred = self.Mask_Outs(mask_feat)
                # mask loss
                loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                    mask_pred, rpn_ret['masks_int32'])
                return_dict['losses']['loss_mask'] = loss_mask

            if cfg.MODEL.KEYPOINTS_ON:
                if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                    # No corresponding keypoint head implemented yet (neither
                    # in Detectron). Also, RPN would need to generate the
                    # label 'roi_has_keypoints_int32'.
                    kps_feat = self.Keypoint_Head(
                        res5_feat,
                        rpn_ret,
                        roi_has_keypoints_int32=rpn_ret[
                            'roi_has_keypoint_int32'])
                else:
                    kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                kps_pred = self.Keypoint_Outs(kps_feat)
                # keypoints loss
                if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'])
                else:
                    loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                        kps_pred, rpn_ret['keypoint_locations_int32'],
                        rpn_ret['keypoint_weights'],
                        rpn_ret['keypoint_loss_normalizer'])
                return_dict['losses']['loss_kps'] = loss_keypoints

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            # Testing
            return_dict['rois'] = rpn_ret['rois']
            return_dict['cls_score'] = cls_score
            return_dict['bbox_pred'] = bbox_pred
            if cfg.TEST.PROPOSALS_OUT:
                import os
                import json
                import numpy as np

                # Apply the box-regression deltas right here and dump the
                # resulting (up to 1000) transformed boxes.
                bbox_pred = bbox_pred.data.cpu().numpy().squeeze()
                box_deltas = bbox_pred.reshape([-1, bbox_pred.shape[-1]])
                shift_boxes = box_utils.bbox_transform(
                    rpn_ret['rois'][:, 1:5], box_deltas,
                    cfg.MODEL.BBOX_REG_WEIGHTS)
                shift_boxes = box_utils.clip_tiled_boxes(
                    shift_boxes,
                    im_info.data.cpu().numpy().squeeze()[0:2])

                num_classes = cfg.MODEL.NUM_CLASSES
                onecls_pred_boxes = []
                inds_all = []
                for j in range(1, num_classes):
                    inds = np.where(cls_score[:, j] > cfg.TEST.SCORE_THRESH)[0]
                    boxes_j = shift_boxes[inds, j * 4:(j + 1) * 4]
                    onecls_pred_boxes += boxes_j.tolist()
                    inds_all.extend(inds.tolist())

                # BUGFIX: np.int was removed in NumPy 1.24; use builtin int.
                inds_all = np.array(inds_all, dtype=int)
                aligned_proposals = rpn_ret['rois'][:, 1:5][inds_all]
                aligned_boxes = np.array(onecls_pred_boxes, dtype=np.float32)

                assert inds_all.shape[0] == aligned_boxes.shape[0]
                assert aligned_proposals.size == aligned_boxes.size

                path = "/nfs/project/libo_i/Boosting/Anchor_Info"
                with open(os.path.join(path, "proposals.json"), "w") as fp:
                    json.dump(aligned_proposals.tolist(), fp)

                with open(os.path.join(path, "boxes.json"), "w") as fp:
                    json.dump(aligned_boxes.tolist(), fp)

        return return_dict
# --- Example #34 (code-sharing-site scrape separator; original marker: "예제 #34" / vote count "0") ---
    def _forward(self, data, im_info, roidb=None, **rpn_kwargs):
        """Core forward pass shared by training and inference.

        Args:
            data: batched input image blob; a tensor on a CUDA device
                (``get_device()`` is called on it below). Assumed NCHW —
                TODO confirm against the data loader.
            im_info: per-image info blob with rows
                [im_height, im_width, im_scale]; converted via
                ``im_info.data.numpy()`` on the precomputed-box path.
            roidb: training only — list of serialized roidb entries, one per
                image; deserialized at the top of this method.
            **rpn_kwargs: extra blobs for the RPN losses; during training the
                RPN cls-logit / bbox-pred outputs are merged in and forwarded
                to ``rpn_heads.generic_rpn_losses``.

        Returns:
            dict. Training: ``{'losses': {...}, 'metrics': {...}}`` with every
            value unsqueezed to 1-dim (see the gather workaround near the end).
            Inference: ``'blob_conv'`` plus, when not using precomputed boxes,
            ``'rois'``, ``'cls_score'`` and ``'bbox_pred'``.
        """
        im_data = data
        if self.training:
            # roidb: list, length = batch size
            # 'has_visible_keypoints': bool
            # 'boxes' & 'gt_classes': object bboxes and classes
            # 'segms', 'seg_areas', 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map': pass
            # 'gt_actions': num_box*26
            # 'gt_role_id': num_box*26*2
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        # Backbone features; with FPN on this is a list of per-level blobs
        # (trimmed for the RoI heads further down).
        blob_conv = self.Conv_Body(im_data)

        # Original RPN module will generate proposals, and sample 256 positive/negative
        # examples in a 1:3 ratio for r-cnn stage. For InteractNet(hoi), I set
        # cfg.TRAIN.BATCH_SIZE_PER_IM and cfg.TRAIN.FG_FRACTION big value, to save
        # every proposal in rpn_ret, then I will re-sample from rpn_ret for three branch
        # of InteractNet, see roi_data/hoi_data.py for more information.
        if not cfg.VCOCO.USE_PRECOMP_BOX:
            rpn_ret = self.RPN(blob_conv, im_info, roidb)
            if cfg.MODEL.VCOCO_ON and self.training:
                # WARNING! always sample hoi branch before detection branch when training
                hoi_blob_in = sample_for_hoi_branch(rpn_ret,
                                                    roidb,
                                                    im_info,
                                                    is_training=True)
                # Re-sampling for RCNN head, rpn_ret will be modified inplace
                sample_for_detection_branch(rpn_ret)
        elif self.training:
            # Precomputed-box path: attach external proposals to the roidb.
            # crowd_thresh=0 disables crowd filtering (consistent with the
            # historical Faster R-CNN behavior).
            json_dataset.add_proposals(roidb,
                                       rois=None,
                                       im_info=im_info.data.numpy(),
                                       crowd_thresh=0)  # NOTE(review): stale '[:, 2]' marker replaced
            hoi_blob_in = sample_for_hoi_branch_precomp_box_train(
                roidb, im_info, is_training=True)
            if hoi_blob_in is None:
                # No valid HOI samples in this batch: return zero-valued
                # losses/metrics so the training loop and multi-GPU gather
                # still receive every expected key.
                return_dict['losses'] = {}
                return_dict['metrics'] = {}
                return_dict['losses'][
                    'loss_hoi_interaction_action'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['metrics'][
                    'accuracy_interaction_cls'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['losses'][
                    'loss_hoi_interaction_affinity'] = torch.tensor(
                        [0.]).cuda(device_id)
                return_dict['metrics'][
                    'accuracy_interaction_affinity'] = torch.tensor(
                        [0.]).cuda(device_id)
                return return_dict

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            #blob_conv = blob_conv[-self.num_roi_levels:]
            if cfg.FPN.MULTILEVEL_ROIS:
                blob_conv = blob_conv[-self.num_roi_levels:]
            else:
                blob_conv = blob_conv[-1]

        if not self.training:
            # Expose the conv features so test-time code can reuse them
            # (e.g. for mask/keypoint inference) without recomputing.
            return_dict['blob_conv'] = blob_conv

        if not cfg.VCOCO.USE_PRECOMP_BOX:
            if not cfg.MODEL.RPN_ONLY:
                if cfg.MODEL.SHARE_RES5 and self.training:
                    # res5_feat is reused by mask/keypoint heads that share res5.
                    box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret)
                else:
                    box_feat = self.Box_Head(blob_conv, rpn_ret)
                cls_score, bbox_pred = self.Box_Outs(box_feat)
            else:
                # TODO: complete the returns for RPN only situation
                pass

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            if not cfg.VCOCO.USE_PRECOMP_BOX:
                # Forward the RPN cls-logit / bbox-pred blobs into the loss kwargs.
                rpn_kwargs.update(
                    dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                         if (k.startswith('rpn_cls_logits')
                             or k.startswith('rpn_bbox_pred'))))
                loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                    **rpn_kwargs)
                if cfg.FPN.FPN_ON:
                    # With FPN, RPN losses come back as per-level lists;
                    # record one entry per pyramid level.
                    for i, lvl in enumerate(
                            range(cfg.FPN.RPN_MIN_LEVEL,
                                  cfg.FPN.RPN_MAX_LEVEL + 1)):
                        return_dict['losses']['loss_rpn_cls_fpn%d' %
                                              lvl] = loss_rpn_cls[i]
                        return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                              lvl] = loss_rpn_bbox[i]
                else:
                    return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                    return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox

                # bbox loss
                loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                    cls_score, bbox_pred, rpn_ret['labels_int32'],
                    rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                    rpn_ret['bbox_outside_weights'])
                return_dict['losses']['loss_cls'] = loss_cls
                return_dict['losses']['loss_bbox'] = loss_bbox
                return_dict['metrics']['accuracy_cls'] = accuracy_cls

                if cfg.MODEL.MASK_ON:
                    if getattr(self.Mask_Head, 'SHARE_RES5', False):
                        # Shared-res5 variant consumes res5_feat computed by Box_Head above.
                        mask_feat = self.Mask_Head(
                            res5_feat,
                            rpn_ret,
                            roi_has_mask_int32=rpn_ret['roi_has_mask_int32'])
                    else:
                        mask_feat = self.Mask_Head(blob_conv, rpn_ret)
                    mask_pred = self.Mask_Outs(mask_feat)
                    # return_dict['mask_pred'] = mask_pred
                    # mask loss
                    loss_mask = mask_rcnn_heads.mask_rcnn_losses(
                        mask_pred, rpn_ret['masks_int32'])
                    return_dict['losses']['loss_mask'] = loss_mask

                if cfg.MODEL.KEYPOINTS_ON:
                    if getattr(self.Keypoint_Head, 'SHARE_RES5', False):
                        # No corresponding keypoint head implemented yet (Neither in Detectron)
                        # Also, rpn need to generate the label 'roi_has_keypoints_int32'
                        kps_feat = self.Keypoint_Head(
                            res5_feat,
                            rpn_ret,
                            roi_has_keypoints_int32=rpn_ret[
                                'roi_has_keypoint_int32'])
                    else:
                        kps_feat = self.Keypoint_Head(blob_conv, rpn_ret)
                    kps_pred = self.Keypoint_Outs(kps_feat)
                    # return_dict['keypoints_pred'] = kps_pred
                    # keypoints loss
                    if cfg.KRCNN.NORMALIZE_BY_VISIBLE_KEYPOINTS:
                        loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                            kps_pred, rpn_ret['keypoint_locations_int32'],
                            rpn_ret['keypoint_weights'])
                    else:
                        loss_keypoints = keypoint_rcnn_heads.keypoint_losses(
                            kps_pred, rpn_ret['keypoint_locations_int32'],
                            rpn_ret['keypoint_weights'],
                            rpn_ret['keypoint_loss_normalizer'])
                    return_dict['losses']['loss_kps'] = loss_keypoints

            if cfg.MODEL.VCOCO_ON:
                # HOI branch: hoi_blob_in was prepared above on whichever
                # proposal path (RPN or precomputed boxes) is active.
                hoi_blob_out = self.HOI_Head(blob_conv, hoi_blob_in)

                interaction_action_loss, interaction_affinity_loss, \
                interaction_action_accuray_cls, interaction_affinity_cls = self.HOI_Head.loss(
                    hoi_blob_out)

                return_dict['losses'][
                    'loss_hoi_interaction_action'] = interaction_action_loss
                return_dict['metrics'][
                    'accuracy_interaction_cls'] = interaction_action_accuray_cls
                return_dict['losses'][
                    'loss_hoi_interaction_affinity'] = interaction_affinity_loss
                return_dict['metrics'][
                    'accuracy_interaction_affinity'] = interaction_affinity_cls

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            # (unsqueeze each loss/metric to 1-dim so multi-GPU gather works;
            # rewriting values while iterating is safe since keys don't change)
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)

        else:
            if not cfg.VCOCO.USE_PRECOMP_BOX:
                # Inference outputs: raw proposals, class scores and box
                # deltas; decoding/NMS is left to the caller.
                return_dict['rois'] = rpn_ret['rois']
                return_dict['cls_score'] = cls_score
                return_dict['bbox_pred'] = bbox_pred

        #print('return ready')
        return return_dict