Пример #1
0
def _sample_node_contrastive_pairs(roidb, im_scale, batch_idx, anchor):
    """Build the node-contrastive blobs anchored on one entity role.

    Samples up to ``cfg.MODEL.NODE_SAMPLE_SIZE`` positive pairs (pair overlap
    >= ``cfg.TRAIN.FG_THRESH``) and up to the same number of negative pairs
    whose ``anchor`` entity is foreground while the pair overlap is below
    ``cfg.TRAIN.BG_THRESH_HI``. Negatives whose other entity is also
    foreground are taken first; the remaining budget is filled with negatives
    whose other entity is not foreground.

    Args:
        roidb: per-image roidb entry; reads 'max_pair_overlaps',
            'max_sbj_overlaps', 'max_obj_overlaps', 'max_prd_classes',
            'max_sbj_classes', 'max_obj_classes', 'sbj_boxes', 'obj_boxes'
            and, when spatial features are enabled, 'width' and 'height'.
        im_scale: factor the boxes are multiplied by to obtain rois.
        batch_idx: value written into the first column of every roi.
        anchor: 'sbj' or 'obj'; the role that must be foreground. It also
            selects the '_sbj_pos' / '_obj_pos' suffix of the blob names.

    Returns:
        dict mapping blob names (suffixed with '_<anchor>_pos') to arrays.
        Positive pairs come first in every blob, followed by the negatives.
    """
    other = 'obj' if anchor == 'sbj' else 'sbj'
    suffix = '_%s_pos' % anchor
    nodes_per_image = cfg.MODEL.NODE_SAMPLE_SIZE
    fg_thresh = cfg.TRAIN.FG_THRESH

    max_pair_overlaps = roidb['max_pair_overlaps']
    anchor_overlaps = roidb['max_%s_overlaps' % anchor]
    other_overlaps = roidb['max_%s_overlaps' % other]

    # Here a naturally existing assumption is, each positive anchor entity
    # should have at least one positive partner entity.
    anchor_is_fg = anchor_overlaps >= fg_thresh
    pair_is_bg = max_pair_overlaps < cfg.TRAIN.BG_THRESH_HI
    pair_pos_inds = np.where(max_pair_overlaps >= fg_thresh)[0]
    other_pos_pair_neg_inds = np.where(
        anchor_is_fg & (other_overlaps >= fg_thresh) & pair_is_bg)[0]
    other_neg_pair_neg_inds = np.where(
        anchor_is_fg & (other_overlaps < fg_thresh) & pair_is_bg)[0]

    # NOTE: the three npr.choice calls below must stay in this order so the
    # random stream is consumed exactly as before.
    if pair_pos_inds.size > 0:
        pair_pos_inds = npr.choice(
            pair_pos_inds,
            size=int(min(nodes_per_image, pair_pos_inds.size)),
            replace=False)
    if other_pos_pair_neg_inds.size > 0:
        other_pos_pair_neg_inds = npr.choice(
            other_pos_pair_neg_inds,
            size=int(min(nodes_per_image, other_pos_pair_neg_inds.size)),
            replace=False)
    pair_neg_inds = other_pos_pair_neg_inds
    # Fill whatever is left of the negative budget with pairs whose other
    # entity is not foreground.
    remaining = nodes_per_image - other_pos_pair_neg_inds.size
    if remaining > 0 and other_neg_pair_neg_inds.size > 0:
        other_neg_pair_neg_inds = npr.choice(
            other_neg_pair_neg_inds,
            size=int(min(remaining, other_neg_pair_neg_inds.size)),
            replace=False)
        pair_neg_inds = np.append(pair_neg_inds, other_neg_pair_neg_inds)
    inds = np.append(pair_pos_inds, pair_neg_inds)
    num_pos = pair_pos_inds.size

    blobs = {}

    # 1 for the leading positive pairs, 0 for the negatives.
    binary_labels = np.zeros(inds.size, dtype=np.int32)
    binary_labels[:num_pos] = 1
    blobs['binary_labels' + suffix + '_int32'] = binary_labels

    # Predicate labels are shifted by one so that 0 marks a negative pair.
    prd_labels = np.zeros(inds.size, dtype=np.int32)
    prd_labels[:num_pos] = roidb['max_prd_classes'][pair_pos_inds] + 1
    blobs['prd_labels' + suffix + '_int32'] = prd_labels

    # 1. set all entity labels > 0
    entity_labels = {
        'sbj': roidb['max_sbj_classes'][inds] + 1,
        'obj': roidb['max_obj_classes'][inds] + 1,
    }
    # 2. zero out the non-anchor entities that are not foreground
    entity_labels[other][other_overlaps[inds] < fg_thresh] = 0
    blobs['sbj_labels' + suffix + '_int32'] = entity_labels['sbj'].astype(
        np.int32, copy=False)
    blobs['obj_labels' + suffix + '_int32'] = entity_labels['obj'].astype(
        np.int32, copy=False)
    # this is for freq bias in RelDN (labels are left unshifted here)
    blobs['sbj_labels' + suffix + '_fg_int32'] = roidb['max_sbj_classes'][
        inds].astype(np.int32, copy=False)
    blobs['obj_labels' + suffix + '_fg_int32'] = roidb['max_obj_classes'][
        inds].astype(np.int32, copy=False)

    sbj_boxes = roidb['sbj_boxes'][inds]
    obj_boxes = roidb['obj_boxes'][inds]
    # Scale rois and format as (batch_idx, x1, y1, x2, y2)
    repeated_batch_idx = batch_idx * blob_utils.ones((inds.shape[0], 1))
    sbj_rois = np.hstack((repeated_batch_idx, sbj_boxes * im_scale))
    obj_rois = np.hstack((repeated_batch_idx, obj_boxes * im_scale))
    blobs['sbj_rois' + suffix] = sbj_rois
    blobs['obj_rois' + suffix] = obj_rois
    blobs['rel_rois' + suffix] = box_utils_rel.rois_union(sbj_rois, obj_rois)

    # Index maps between the sampled rows and the distinct anchor rois.
    anchor_rois = sbj_rois if anchor == 'sbj' else obj_rois
    _, inds_unique, inds_reverse = np.unique(anchor_rois,
                                             return_index=True,
                                             return_inverse=True,
                                             axis=0)
    assert inds_reverse.shape[0] == anchor_rois.shape[0]
    blobs['inds_unique' + suffix] = inds_unique
    blobs['inds_reverse' + suffix] = inds_reverse

    if cfg.MODEL.USE_SPATIAL_FEAT:
        blobs['spt_feat' + suffix] = box_utils_rel.get_spt_features(
            sbj_boxes, obj_boxes, roidb['width'], roidb['height'])
    return blobs


def _sample_pairs(roidb, im_scale, batch_idx):
    """Sample foreground (and optionally background) subject-object pairs.

    Pairs whose max overlap is (numerically) 1.0 are treated as ground truth
    and are always kept. Other pairs with overlap >= ``cfg.TRAIN.FG_THRESH``
    are sampled without replacement to fill the foreground budget
    ``cfg.TRAIN.FG_REL_SIZE_PER_IM``. When ``cfg.MODEL.USE_BG`` is set, pairs
    with overlap < ``cfg.TRAIN.BG_THRESH_HI`` are sampled as background.

    Args:
        roidb: per-image roidb entry; reads 'max_pair_overlaps',
            'max_prd_classes', 'sbj_boxes', 'obj_boxes' and, depending on
            the enabled cfg.MODEL options, 'width', 'height',
            'max_sbj_classes', 'max_obj_classes', 'max_sbj_overlaps' and
            'max_obj_overlaps'.
        im_scale: factor the boxes are multiplied by to obtain rois.
        batch_idx: value written into the first column of every roi.

    Returns:
        dict of blobs: 'fg_prd_labels_int32', 'all_prd_labels_int32',
        'fg_size', 'sbj_rois', 'obj_rois', 'rel_rois', plus 'spt_feat',
        'all_sbj_labels_int32' / 'all_obj_labels_int32' and the
        '*_sbj_pos' / '*_obj_pos' node-contrastive blobs when the
        corresponding cfg.MODEL options are enabled.
    """
    fg_pairs_per_image = cfg.TRAIN.FG_REL_SIZE_PER_IM
    pairs_per_image = int(
        cfg.TRAIN.FG_REL_SIZE_PER_IM /
        cfg.TRAIN.FG_REL_FRACTION)  # need much more pairs since it's quadratic
    max_pair_overlaps = roidb['max_pair_overlaps']

    gt_pair_inds = np.where(max_pair_overlaps > 1.0 - 1e-4)[0]
    fg_pair_inds = np.where((max_pair_overlaps >= cfg.TRAIN.FG_THRESH)
                            & (max_pair_overlaps <= 1.0 - 1e-4))[0]

    fg_pairs_per_this_image = np.minimum(fg_pairs_per_image,
                                         gt_pair_inds.size + fg_pair_inds.size)
    # Sample foreground regions without replacement
    if fg_pair_inds.size > 0:
        # The gt pairs alone may already exceed the foreground budget; clamp
        # so npr.choice never receives a negative size (it would raise).
        num_sampled_fg = max(fg_pairs_per_this_image - gt_pair_inds.size, 0)
        fg_pair_inds = npr.choice(fg_pair_inds,
                                  size=num_sampled_fg,
                                  replace=False)
    fg_pair_inds = np.append(fg_pair_inds, gt_pair_inds)

    # Label is the class each RoI has max overlap with
    fg_prd_labels = roidb['max_prd_classes'][fg_pair_inds]
    blob_dict = dict(
        fg_prd_labels_int32=fg_prd_labels.astype(np.int32, copy=False))
    if cfg.MODEL.USE_BG:
        bg_pair_inds = np.where(
            (max_pair_overlaps < cfg.TRAIN.BG_THRESH_HI))[0]

        # Compute number of background RoIs to take from this image (guarding
        # against there being fewer than desired)
        bg_pairs_per_this_image = pairs_per_image - fg_pairs_per_this_image
        bg_pairs_per_this_image = np.minimum(bg_pairs_per_this_image,
                                             bg_pair_inds.size)
        # Sample background regions without replacement
        if bg_pair_inds.size > 0:
            bg_pair_inds = npr.choice(bg_pair_inds,
                                      size=bg_pairs_per_this_image,
                                      replace=False)
        keep_pair_inds = np.append(fg_pair_inds, bg_pair_inds)
        # Foreground pairs come first; background pairs keep label 0.
        all_prd_labels = np.zeros(keep_pair_inds.size, dtype=np.int32)
        all_prd_labels[:fg_pair_inds.
                       size] = fg_prd_labels + 1  # class should start from 1
    else:
        keep_pair_inds = fg_pair_inds
        all_prd_labels = fg_prd_labels
    blob_dict['all_prd_labels_int32'] = all_prd_labels.astype(np.int32,
                                                              copy=False)
    blob_dict['fg_size'] = np.array(
        [fg_pair_inds.size], dtype=np.int32
    )  # this is used to check if there is at least one fg to learn

    sampled_sbj_boxes = roidb['sbj_boxes'][keep_pair_inds]
    sampled_obj_boxes = roidb['obj_boxes'][keep_pair_inds]
    # Scale rois and format as (batch_idx, x1, y1, x2, y2)
    repeated_batch_idx = batch_idx * blob_utils.ones(
        (keep_pair_inds.shape[0], 1))
    sampled_sbj_rois = np.hstack(
        (repeated_batch_idx, sampled_sbj_boxes * im_scale))
    sampled_obj_rois = np.hstack(
        (repeated_batch_idx, sampled_obj_boxes * im_scale))
    blob_dict['sbj_rois'] = sampled_sbj_rois
    blob_dict['obj_rois'] = sampled_obj_rois
    blob_dict['rel_rois'] = box_utils_rel.rois_union(sampled_sbj_rois,
                                                     sampled_obj_rois)
    if cfg.MODEL.USE_SPATIAL_FEAT:
        blob_dict['spt_feat'] = box_utils_rel.get_spt_features(
            sampled_sbj_boxes, sampled_obj_boxes, roidb['width'],
            roidb['height'])
    if cfg.MODEL.USE_FREQ_BIAS:
        sbj_labels = roidb['max_sbj_classes'][keep_pair_inds]
        obj_labels = roidb['max_obj_classes'][keep_pair_inds]
        blob_dict['all_sbj_labels_int32'] = sbj_labels.astype(np.int32,
                                                              copy=False)
        blob_dict['all_obj_labels_int32'] = obj_labels.astype(np.int32,
                                                              copy=False)
    if (cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS
            or cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS
            or cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS):
        # Subject-anchored blobs first, then object-anchored ones; the order
        # fixes both the RNG consumption and the blob insertion order.
        for anchor in ('sbj', 'obj'):
            blob_dict.update(
                _sample_node_contrastive_pairs(roidb, im_scale, batch_idx,
                                               anchor))

    return blob_dict
    def _forward(self,
                 data,
                 im_info,
                 do_vis=False,
                 dataset_name=None,
                 roidb=None,
                 use_gt_labels=False,
                 **rpn_kwargs):
        im_data = data
        if self.training:
            roidb = list(map(lambda x: blob_utils.deserialize(x)[0], roidb))
        if dataset_name is not None:
            dataset_name = blob_utils.deserialize(dataset_name)
        else:
            dataset_name = cfg.TRAIN.DATASETS[
                0] if self.training else cfg.TEST.DATASETS[
                    0]  # assuming only one dataset per run

        device_id = im_data.get_device()

        return_dict = {}  # A dict to collect return variables

        blob_conv = self.Conv_Body(im_data)
        if not cfg.MODEL.USE_REL_PYRAMID:
            blob_conv_prd = self.Prd_RCNN.Conv_Body(im_data)

        rpn_ret = self.RPN(blob_conv, im_info, roidb)

        if cfg.FPN.FPN_ON:
            # Retain only the blobs that will be used for RoI heads. `blob_conv` may include
            # extra blobs that are used for RPN proposals, but not for RoI heads.
            blob_conv = blob_conv[-self.num_roi_levels:]
            if not cfg.MODEL.USE_REL_PYRAMID:
                blob_conv_prd = blob_conv_prd[-self.num_roi_levels:]
            else:
                blob_conv_prd = self.RelPyramid(blob_conv)

        if cfg.MODEL.SHARE_RES5 and self.training:
            box_feat, res5_feat = self.Box_Head(blob_conv,
                                                rpn_ret,
                                                use_relu=True)
        else:
            box_feat = self.Box_Head(blob_conv, rpn_ret, use_relu=True)
        cls_score, bbox_pred = self.Box_Outs(box_feat)

        # now go through the predicate branch
        use_relu = False if cfg.MODEL.NO_FC7_RELU else True
        if self.training:
            fg_inds = np.where(rpn_ret['labels_int32'] > 0)[0]
            det_rois = rpn_ret['rois'][fg_inds]
            det_labels = rpn_ret['labels_int32'][fg_inds]
            det_scores = F.softmax(cls_score[fg_inds], dim=1)
            rel_ret = self.RelPN(det_rois, det_labels, det_scores, im_info,
                                 dataset_name, roidb)
            if cfg.MODEL.ADD_SO_SCORES:
                sbj_feat = self.S_Head(blob_conv,
                                       rel_ret,
                                       rois_name='sbj_rois',
                                       use_relu=use_relu)
                obj_feat = self.O_Head(blob_conv,
                                       rel_ret,
                                       rois_name='obj_rois',
                                       use_relu=use_relu)
            else:
                sbj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='sbj_rois',
                                         use_relu=use_relu)
                obj_feat = self.Box_Head(blob_conv,
                                         rel_ret,
                                         rois_name='obj_rois',
                                         use_relu=use_relu)
            if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                if cfg.MODEL.ADD_SO_SCORES:
                    # sbj
                    sbj_feat_sbj_pos = self.S_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_sbj_pos',
                        use_relu=use_relu)
                    obj_feat_sbj_pos = self.O_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_sbj_pos',
                        use_relu=use_relu)
                    # obj
                    sbj_feat_obj_pos = self.S_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_obj_pos',
                        use_relu=use_relu)
                    obj_feat_obj_pos = self.O_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_obj_pos',
                        use_relu=use_relu)
                else:
                    # sbj
                    sbj_feat_sbj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_sbj_pos',
                        use_relu=use_relu)
                    obj_feat_sbj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_sbj_pos',
                        use_relu=use_relu)
                    # obj
                    sbj_feat_obj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='sbj_rois_obj_pos',
                        use_relu=use_relu)
                    obj_feat_obj_pos = self.Box_Head(
                        blob_conv,
                        rel_ret,
                        rois_name='obj_rois_obj_pos',
                        use_relu=use_relu)
        else:
            if roidb is not None:
                im_scale = im_info.data.numpy()[:, 2][0]
                im_w = im_info.data.numpy()[:, 1][0]
                im_h = im_info.data.numpy()[:, 0][0]
                sbj_boxes = roidb['sbj_gt_boxes']
                obj_boxes = roidb['obj_gt_boxes']
                sbj_rois = sbj_boxes * im_scale
                obj_rois = obj_boxes * im_scale
                repeated_batch_idx = 0 * blob_utils.ones(
                    (sbj_rois.shape[0], 1))
                sbj_rois = np.hstack((repeated_batch_idx, sbj_rois))
                obj_rois = np.hstack((repeated_batch_idx, obj_rois))
                rel_rois = box_utils_rel.rois_union(sbj_rois, obj_rois)
                rel_ret = {}
                rel_ret['sbj_rois'] = sbj_rois
                rel_ret['obj_rois'] = obj_rois
                rel_ret['rel_rois'] = rel_rois
                if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
                    lvl_min = cfg.FPN.ROI_MIN_LEVEL
                    lvl_max = cfg.FPN.ROI_MAX_LEVEL
                    rois_blob_names = ['sbj_rois', 'obj_rois', 'rel_rois']
                    for rois_blob_name in rois_blob_names:
                        # Add per FPN level roi blobs named like: <rois_blob_name>_fpn<lvl>
                        target_lvls = fpn_utils.map_rois_to_fpn_levels(
                            rel_ret[rois_blob_name][:, 1:5], lvl_min, lvl_max)
                        fpn_utils.add_multilevel_roi_blobs(
                            rel_ret, rois_blob_name, rel_ret[rois_blob_name],
                            target_lvls, lvl_min, lvl_max)
                sbj_det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='sbj_rois',
                                             use_relu=True)
                sbj_cls_scores, _ = self.Box_Outs(sbj_det_feat)
                sbj_cls_scores = sbj_cls_scores.data.cpu().numpy()
                obj_det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='obj_rois',
                                             use_relu=True)
                obj_cls_scores, _ = self.Box_Outs(obj_det_feat)
                obj_cls_scores = obj_cls_scores.data.cpu().numpy()
                if use_gt_labels:
                    sbj_labels = roidb['sbj_gt_classes']  # start from 0
                    obj_labels = roidb['obj_gt_classes']  # start from 0
                    sbj_scores = np.ones_like(sbj_labels, dtype=np.float32)
                    obj_scores = np.ones_like(obj_labels, dtype=np.float32)
                else:
                    sbj_labels = np.argmax(sbj_cls_scores[:, 1:], axis=1)
                    obj_labels = np.argmax(obj_cls_scores[:, 1:], axis=1)
                    sbj_scores = np.amax(sbj_cls_scores[:, 1:], axis=1)
                    obj_scores = np.amax(obj_cls_scores[:, 1:], axis=1)
                rel_ret['sbj_scores'] = sbj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['obj_scores'] = obj_scores.astype(np.float32,
                                                          copy=False)
                rel_ret['sbj_labels'] = sbj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['obj_labels'] = obj_labels.astype(
                    np.int32, copy=False) + 1  # need to start from 1
                rel_ret['all_sbj_labels_int32'] = sbj_labels.astype(np.int32,
                                                                    copy=False)
                rel_ret['all_obj_labels_int32'] = obj_labels.astype(np.int32,
                                                                    copy=False)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat = box_utils_rel.get_spt_features(
                        sbj_boxes, obj_boxes, im_w, im_h)
                    rel_ret['spt_feat'] = spt_feat
                if cfg.MODEL.ADD_SO_SCORES:
                    sbj_feat = self.S_Head(blob_conv,
                                           rel_ret,
                                           rois_name='sbj_rois',
                                           use_relu=use_relu)
                    obj_feat = self.O_Head(blob_conv,
                                           rel_ret,
                                           rois_name='obj_rois',
                                           use_relu=use_relu)
                else:
                    sbj_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='sbj_rois',
                                             use_relu=use_relu)
                    obj_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='obj_rois',
                                             use_relu=use_relu)
            else:
                score_thresh = cfg.TEST.SCORE_THRESH
                while score_thresh >= -1e-06:  # a negative value very close to 0.0
                    det_rois, det_labels, det_scores = \
                        self.prepare_det_rois(rpn_ret['rois'], cls_score, bbox_pred, im_info, score_thresh)
                    rel_ret = self.RelPN(det_rois, det_labels, det_scores,
                                         im_info, dataset_name, roidb)
                    valid_len = len(rel_ret['rel_rois'])
                    if valid_len > 0:
                        break
                    logger.info(
                        'Got {} rel_rois when score_thresh={}, changing to {}'.
                        format(valid_len, score_thresh, score_thresh - 0.01))
                    score_thresh -= 0.01
                if cfg.MODEL.ADD_SO_SCORES:
                    det_s_feat = self.S_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    det_o_feat = self.O_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    sbj_feat = det_s_feat[rel_ret['sbj_inds']]
                    obj_feat = det_o_feat[rel_ret['obj_inds']]
                else:
                    det_feat = self.Box_Head(blob_conv,
                                             rel_ret,
                                             rois_name='det_rois',
                                             use_relu=use_relu)
                    sbj_feat = det_feat[rel_ret['sbj_inds']]
                    obj_feat = det_feat[rel_ret['obj_inds']]

        rel_feat = self.Prd_RCNN.Box_Head(blob_conv_prd,
                                          rel_ret,
                                          rois_name='rel_rois',
                                          use_relu=use_relu)

        spo_feat = torch.cat((sbj_feat, rel_feat, obj_feat), dim=1)
        if cfg.MODEL.USE_SPATIAL_FEAT:
            spt_feat = rel_ret['spt_feat']
        else:
            spt_feat = None
        if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
            sbj_labels = rel_ret['all_sbj_labels_int32']
            obj_labels = rel_ret['all_obj_labels_int32']
        else:
            sbj_labels = None
            obj_labels = None

        # prd_scores is the visual scores. See reldn_heads.py
        prd_scores, prd_bias_scores, prd_spt_scores, ttl_cls_scores, sbj_cls_scores, obj_cls_scores = \
            self.RelDN(spo_feat, spt_feat, sbj_labels, obj_labels, sbj_feat, obj_feat)

        if self.training:
            return_dict['losses'] = {}
            return_dict['metrics'] = {}
            # rpn loss
            rpn_kwargs.update(
                dict((k, rpn_ret[k]) for k in rpn_ret.keys()
                     if (k.startswith('rpn_cls_logits')
                         or k.startswith('rpn_bbox_pred'))))
            loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(
                **rpn_kwargs)
            if cfg.FPN.FPN_ON:
                for i, lvl in enumerate(
                        range(cfg.FPN.RPN_MIN_LEVEL,
                              cfg.FPN.RPN_MAX_LEVEL + 1)):
                    return_dict['losses']['loss_rpn_cls_fpn%d' %
                                          lvl] = loss_rpn_cls[i]
                    return_dict['losses']['loss_rpn_bbox_fpn%d' %
                                          lvl] = loss_rpn_bbox[i]
            else:
                return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls
                return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox
            # bbox loss
            loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses(
                cls_score, bbox_pred, rpn_ret['labels_int32'],
                rpn_ret['bbox_targets'], rpn_ret['bbox_inside_weights'],
                rpn_ret['bbox_outside_weights'])
            return_dict['losses']['loss_cls'] = loss_cls
            return_dict['losses']['loss_bbox'] = loss_bbox
            return_dict['metrics']['accuracy_cls'] = accuracy_cls

            if cfg.MODEL.USE_FREQ_BIAS and not cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_bias, accuracy_cls_bias = reldn_heads.reldn_losses(
                    prd_bias_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_bias'] = loss_cls_bias
                return_dict['metrics']['accuracy_cls_bias'] = accuracy_cls_bias
            if cfg.MODEL.USE_SPATIAL_FEAT and not cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_spt, accuracy_cls_spt = reldn_heads.reldn_losses(
                    prd_spt_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_spt'] = loss_cls_spt
                return_dict['metrics']['accuracy_cls_spt'] = accuracy_cls_spt
            if cfg.MODEL.ADD_SCORES_ALL:
                loss_cls_ttl, accuracy_cls_ttl = reldn_heads.reldn_losses(
                    ttl_cls_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_ttl'] = loss_cls_ttl
                return_dict['metrics']['accuracy_cls_ttl'] = accuracy_cls_ttl
            else:
                loss_cls_prd, accuracy_cls_prd = reldn_heads.reldn_losses(
                    prd_scores, rel_ret['all_prd_labels_int32'])
                return_dict['losses']['loss_cls_prd'] = loss_cls_prd
                return_dict['metrics']['accuracy_cls_prd'] = accuracy_cls_prd
            if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS or cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                # sbj
                rel_feat_sbj_pos = self.Prd_RCNN.Box_Head(
                    blob_conv_prd,
                    rel_ret,
                    rois_name='rel_rois_sbj_pos',
                    use_relu=use_relu)
                spo_feat_sbj_pos = torch.cat(
                    (sbj_feat_sbj_pos, rel_feat_sbj_pos, obj_feat_sbj_pos),
                    dim=1)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat_sbj_pos = rel_ret['spt_feat_sbj_pos']
                else:
                    spt_feat_sbj_pos = None
                if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
                    sbj_labels_sbj_pos_fg = rel_ret[
                        'sbj_labels_sbj_pos_fg_int32']
                    obj_labels_sbj_pos_fg = rel_ret[
                        'obj_labels_sbj_pos_fg_int32']
                else:
                    sbj_labels_sbj_pos_fg = None
                    obj_labels_sbj_pos_fg = None
                _, prd_bias_scores_sbj_pos, _, ttl_cls_scores_sbj_pos, _, _ = \
                    self.RelDN(spo_feat_sbj_pos, spt_feat_sbj_pos, sbj_labels_sbj_pos_fg, obj_labels_sbj_pos_fg, sbj_feat_sbj_pos, obj_feat_sbj_pos)
                # obj
                rel_feat_obj_pos = self.Prd_RCNN.Box_Head(
                    blob_conv_prd,
                    rel_ret,
                    rois_name='rel_rois_obj_pos',
                    use_relu=use_relu)
                spo_feat_obj_pos = torch.cat(
                    (sbj_feat_obj_pos, rel_feat_obj_pos, obj_feat_obj_pos),
                    dim=1)
                if cfg.MODEL.USE_SPATIAL_FEAT:
                    spt_feat_obj_pos = rel_ret['spt_feat_obj_pos']
                else:
                    spt_feat_obj_pos = None
                if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
                    sbj_labels_obj_pos_fg = rel_ret[
                        'sbj_labels_obj_pos_fg_int32']
                    obj_labels_obj_pos_fg = rel_ret[
                        'obj_labels_obj_pos_fg_int32']
                else:
                    sbj_labels_obj_pos_fg = None
                    obj_labels_obj_pos_fg = None
                _, prd_bias_scores_obj_pos, _, ttl_cls_scores_obj_pos, _, _ = \
                    self.RelDN(spo_feat_obj_pos, spt_feat_obj_pos, sbj_labels_obj_pos_fg, obj_labels_obj_pos_fg, sbj_feat_obj_pos, obj_feat_obj_pos)
                if cfg.MODEL.USE_NODE_CONTRASTIVE_LOSS:
                    loss_contrastive_sbj, loss_contrastive_obj = reldn_heads.reldn_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_contrastive_sbj'] = loss_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_WEIGHT
                    return_dict['losses'][
                        'loss_contrastive_obj'] = loss_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_WEIGHT
                if cfg.MODEL.USE_NODE_CONTRASTIVE_SO_AWARE_LOSS:
                    loss_so_contrastive_sbj, loss_so_contrastive_obj = reldn_heads.reldn_so_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_so_contrastive_sbj'] = loss_so_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_SO_AWARE_WEIGHT
                    return_dict['losses'][
                        'loss_so_contrastive_obj'] = loss_so_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_SO_AWARE_WEIGHT
                if cfg.MODEL.USE_NODE_CONTRASTIVE_P_AWARE_LOSS:
                    loss_p_contrastive_sbj, loss_p_contrastive_obj = reldn_heads.reldn_p_contrastive_losses(
                        ttl_cls_scores_sbj_pos, ttl_cls_scores_obj_pos,
                        prd_bias_scores_sbj_pos, prd_bias_scores_obj_pos,
                        rel_ret)
                    return_dict['losses'][
                        'loss_p_contrastive_sbj'] = loss_p_contrastive_sbj * cfg.MODEL.NODE_CONTRASTIVE_P_AWARE_WEIGHT
                    return_dict['losses'][
                        'loss_p_contrastive_obj'] = loss_p_contrastive_obj * cfg.MODEL.NODE_CONTRASTIVE_P_AWARE_WEIGHT

            # pytorch0.4 bug on gathering scalar(0-dim) tensors
            for k, v in return_dict['losses'].items():
                return_dict['losses'][k] = v.unsqueeze(0)
            for k, v in return_dict['metrics'].items():
                return_dict['metrics'][k] = v.unsqueeze(0)
        else:
            # Testing
            return_dict['sbj_rois'] = rel_ret['sbj_rois']
            return_dict['obj_rois'] = rel_ret['obj_rois']
            return_dict['sbj_labels'] = rel_ret['sbj_labels']
            return_dict['obj_labels'] = rel_ret['obj_labels']
            return_dict['sbj_scores'] = rel_ret['sbj_scores']
            return_dict['obj_scores'] = rel_ret['obj_scores']
            return_dict['prd_scores'] = prd_scores
            if cfg.MODEL.USE_FREQ_BIAS:
                return_dict['prd_scores_bias'] = prd_bias_scores
            if cfg.MODEL.USE_SPATIAL_FEAT:
                return_dict['prd_scores_spt'] = prd_spt_scores
            if cfg.MODEL.ADD_SCORES_ALL:
                return_dict['prd_ttl_scores'] = ttl_cls_scores
            if do_vis:
                return_dict['blob_conv'] = blob_conv
                return_dict['blob_conv_prd'] = blob_conv_prd

        return return_dict
Пример #3
0
    def forward(self,
                det_rois,
                det_labels,
                det_scores,
                im_info,
                dataset_name,
                roidb=None):
        """Build subject/object candidate pairs from a set of detections.

        Every ordered combination of two detections is enumerated, self pairs
        are removed (when there is more than one detection) and, if
        ``cfg.MODEL.USE_OVLP_FILTER`` is set, only pairs whose boxes overlap
        are kept.

        Args:
            det_rois: 2-D array with one row per detection. Columns ``1:`` are
                divided by the image scale to obtain boxes, and columns
                ``1:5`` are used as box coordinates for FPN level mapping.
                Column 0 is presumably the batch index -- TODO confirm.
            det_labels: per-detection labels, indexed by detection index.
                Only read at inference time.
            det_scores: per-detection scores, indexed by detection index.
                Only read at inference time.
            im_info: object whose ``.data.numpy()`` yields a 2-D array; the
                first row is read as ``(height, width, scale)``.
            dataset_name: not used by this method.
            roidb: optional sequence; when given it must hold exactly one
                entry. Forwarded to ``RelPN_GenerateProposalLabels`` in
                training mode.

        Returns:
            dict: in training mode, a dict filled from the output of
            ``self.RelPN_GenerateProposalLabels``. Otherwise a dict holding
            the pair indices, rois, labels, scores, union rois, ``fg_size``
            and, depending on ``cfg``, spatial features, zero-based int32
            labels and per-FPN-level roi blobs.
        """
        # We always feed one image per batch during training.
        if roidb is not None:
            assert len(roidb) == 1

        # Enumerate every ordered (subject, object) combination of detections.
        num_dets = det_rois.shape[0]
        sbj_inds = np.repeat(np.arange(num_dets), num_dets)
        obj_inds = np.tile(np.arange(num_dets), num_dets)
        # Self pairs are dropped only when there is more than one detection;
        # with a single detection the lone (0, 0) pair is left in place.
        if num_dets > 1:
            sbj_inds, obj_inds = self.remove_self_pairs(
                num_dets, sbj_inds, obj_inds)
        sbj_rois, obj_rois = det_rois[sbj_inds], det_rois[obj_inds]

        # Undo the image scaling to get boxes in the original image scale.
        im_info_np = im_info.data.numpy()
        im_scale = im_info_np[:, 2][0]
        sbj_boxes = sbj_rois[:, 1:] / im_scale
        obj_boxes = obj_rois[:, 1:] / im_scale

        # Optionally keep only the pairs whose boxes overlap in the original
        # scale.
        if cfg.MODEL.USE_OVLP_FILTER:
            pair_ovlp = box_utils_rel.bbox_pair_overlaps(
                sbj_boxes.astype(dtype=np.float32, copy=False),
                obj_boxes.astype(dtype=np.float32, copy=False))
            keep = np.where(pair_ovlp > 0)[0]
            sbj_inds, obj_inds = sbj_inds[keep], obj_inds[keep]
            sbj_rois, obj_rois = sbj_rois[keep], obj_rois[keep]
            sbj_boxes, obj_boxes = sbj_boxes[keep], obj_boxes[keep]

        if self.training:
            # Add binary relationships
            return_dict = {}
            return_dict.update(
                self.RelPN_GenerateProposalLabels(
                    sbj_rois, obj_rois, det_rois, roidb, im_info))
            return return_dict

        # Inference: gather per-pair labels and scores from the detections.
        sbj_labels, obj_labels = det_labels[sbj_inds], det_labels[obj_inds]
        sbj_scores, obj_scores = det_scores[sbj_inds], det_scores[obj_inds]
        rel_rois = box_utils_rel.rois_union(sbj_rois, obj_rois)
        return_dict = {
            'det_rois': det_rois,
            'sbj_inds': sbj_inds,
            'obj_inds': obj_inds,
            'sbj_rois': sbj_rois,
            'obj_rois': obj_rois,
            'rel_rois': rel_rois,
            'sbj_labels': sbj_labels,
            'obj_labels': obj_labels,
            'sbj_scores': sbj_scores,
            'obj_scores': obj_scores,
            'fg_size': np.array([sbj_rois.shape[0]], dtype=np.int32),
        }

        if cfg.MODEL.USE_SPATIAL_FEAT:
            im_w = im_info_np[:, 1][0]
            im_h = im_info_np[:, 0][0]
            return_dict['spt_feat'] = box_utils_rel.get_spt_features(
                sbj_boxes, obj_boxes, im_w, im_h)

        if cfg.MODEL.USE_FREQ_BIAS or cfg.MODEL.RUN_BASELINE:
            # det_labels start from 1, so shift them to be zero-based.
            return_dict['all_sbj_labels_int32'] = sbj_labels.astype(
                np.int32, copy=False) - 1
            return_dict['all_obj_labels_int32'] = obj_labels.astype(
                np.int32, copy=False) - 1

        if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
            lvl_min, lvl_max = cfg.FPN.ROI_MIN_LEVEL, cfg.FPN.ROI_MAX_LEVEL
            # When min_rel_area is used, the same sbj/obj area can map to
            # different feature levels depending on the relationship it
            # belongs to, so det_rois features cannot simply be computed once
            # and gathered; sbj/obj have to be gathered per relationship,
            # which is why sbj_rois/obj_rois are returned as well.
            for blob_name in ('det_rois', 'rel_rois'):
                # Add per FPN level roi blobs named like: <blob_name>_fpn<lvl>
                blob = return_dict[blob_name]
                target_lvls = fpn_utils.map_rois_to_fpn_levels(
                    blob[:, 1:5], lvl_min, lvl_max)
                fpn_utils.add_multilevel_roi_blobs(
                    return_dict, blob_name, blob, target_lvls, lvl_min,
                    lvl_max)

        return return_dict