Example #1
0
    def _rel_forward_train(self,
                           x,
                           gt_bboxes,
                           gt_labels,
                           gt_rel,
                           gt_instid,
                           rel_train_cfg,
                           semantic_feat=None,
                           im_width=None,
                           im_height=None,
                           debug_image=None,
                           debug_filename=None):
        """Compute the relation-head training loss from ground-truth boxes.

        Candidate subject/object pairs are built with ``sample_pairs`` and
        labelled with ``assign_pairs``. RoI features of the subject, object
        and union boxes are extracted without gradients, flattened,
        concatenated and scored by ``self.reldn_head``.

        Args:
            x: Feature maps; only the first
                ``len(mask_roi_extractor.featmap_strides)`` entries are used.
            gt_bboxes: Ground-truth boxes, one row per instance.
            gt_labels: Ground-truth labels, one entry per instance.
            gt_rel: Ground-truth relations, forwarded to ``assign_pairs``.
            gt_instid: Instance ids, one entry per instance.
            rel_train_cfg: Not read by this method.
            semantic_feat: Semantic feature map; read only when
                ``self.with_semantic`` is true and ``'mask'`` is in
                ``self.semantic_fusion``.
            im_width: Forwarded to ``get_spatial_feature``.
            im_height: Forwarded to ``get_spatial_feature``.
            debug_image: Not read by this method.
            debug_filename: Not read by this method.

        Returns:
            The value returned by ``self.reldn_head.loss``.
        """
        assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]

        combined_bboxes = gt_bboxes
        combined_labels = gt_labels
        (sbj_bboxes, sbj_labels, sbj_idxs,
         obj_bboxes, obj_labels, obj_idxs) = sample_pairs(combined_bboxes,
                                                          combined_labels)

        mask_roi_extractor = self.mask_roi_extractor[-1]
        fuse_semantic = self.with_semantic and 'mask' in self.semantic_fusion

        def _extract_roi_feats(rois):
            # RoI features, optionally summed with pooled semantic features.
            feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], rois)
            if fuse_semantic:
                semantic = self.semantic_roi_extractor([semantic_feat], rois)
                if semantic.shape[-2:] != feats.shape[-2:]:
                    semantic = F.adaptive_avg_pool2d(semantic,
                                                     feats.shape[-2:])
                feats += semantic
            return feats

        # Per-instance features; no gradient flows into the extractor.
        with torch.no_grad():
            mask_feats = _extract_roi_feats(bbox2roi([combined_bboxes]))

        relations, relation_targets = \
            assign_pairs(sbj_bboxes, sbj_labels, sbj_idxs,
                         obj_bboxes, obj_labels, obj_idxs,
                         gt_bboxes, gt_labels, gt_rel, gt_instid)

        # Multi-hot target matrix: one row per relation, one column per
        # relation category.
        num_relations = relations.shape[0]
        num_relation_cats = 31  # hard coded
        targets = relations.new_zeros((num_relations, num_relation_cats),
                                      dtype=torch.long)
        for i in range(num_relations):
            targets[i, relation_targets[i]] = 1

        # The last two columns of ``relations`` index into the instance
        # list (subject, then object).
        sbj_index = relations[:, -2].long()
        obj_index = relations[:, -1].long()

        # get union bboxes
        sbj_bboxes = combined_bboxes[sbj_index, :]
        obj_bboxes = combined_bboxes[obj_index, :]
        union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
        union_rois = bbox2roi([union_bboxes])

        # get visual features
        sbj_feats = mask_feats[sbj_index, ...]
        obj_feats = mask_feats[obj_index, ...]
        with torch.no_grad():
            union_feats = _extract_roi_feats(union_rois)

        sbj_feats = sbj_feats.view(num_relations, -1)
        obj_feats = obj_feats.view(num_relations, -1)
        union_feats = union_feats.view(num_relations, -1)

        assert sbj_feats.shape == obj_feats.shape == union_feats.shape

        visual_features = torch.cat([sbj_feats, obj_feats, union_feats],
                                    dim=-1)

        spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                               im_width, im_height)

        # NOTE(review): sbj_labels/obj_labels still come from sample_pairs,
        # while the boxes and features were re-indexed through ``relations``.
        # This presumably relies on assign_pairs keeping the pair order and
        # count -- confirm.
        prd_vis_scores, prd_bias_scores, prd_spt_scores \
            = self.reldn_head(visual_features,
                              sbj_labels=sbj_labels,
                              obj_labels=obj_labels,
                              spt_feat=spatial_features,
                              sbj_feat=sbj_feats,
                              obj_feat=obj_feats)

        label_weights = prd_vis_scores.new_ones(len(prd_vis_scores),
                                                dtype=torch.long)

        loss = self.reldn_head.loss(prd_vis_scores,
                                    targets,
                                    label_weights,
                                    prd_spt_score=prd_spt_scores)

        return loss
Example #2
0
    def _rel_forward_train_binary(self,
                                  x,
                                  gt_bboxes,
                                  gt_labels,
                                  gt_rel,
                                  gt_instid,
                                  debug_image=None,
                                  debug_filename=None):
        """Compute the binary (related / not related) relation loss.

        Candidate pairs are built from the ground-truth boxes, labelled by
        ``assign_pairs`` and subsampled so that at most three negatives are
        kept per positive. Subject, object and union RoI features are then
        scored by ``self.reldn_binary_head``.

        Args:
            x: Feature maps; only the first
                ``len(mask_roi_extractor.featmap_strides)`` entries are used.
            gt_bboxes: Ground-truth boxes, one row per instance.
            gt_labels: Ground-truth labels, one entry per instance.
            gt_rel: Ground-truth relations, forwarded to ``assign_pairs``.
            gt_instid: Instance ids, one entry per instance.
            debug_image: When not ``None``, one figure per sampled relation
                is written under ``./tmp/<debug_filename minus last 4
                characters>/``.
            debug_filename: Used for diagnostics and for the debug directory
                name; must be a string when ``debug_image`` is given.

        Returns:
            The value returned by ``self.reldn_binary_head.loss``.

        Raises:
            Exception: Errors from ``sample_pairs`` or from reading the
                relation labels are re-raised after the offending inputs
                have been printed.
        """
        assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]

        try:
            (sbj_bboxes, sbj_labels, sbj_idxs,
             obj_bboxes, obj_labels, obj_idxs) = sample_pairs(gt_bboxes,
                                                              gt_labels)
        except Exception:
            # Print the offending sample, then propagate: continuing would
            # only fail later on unbound names and hide the real error.
            print(gt_bboxes, debug_filename)
            raise

        mask_roi_extractor = self.mask_roi_extractor[-1]

        with torch.no_grad():
            mask_rois = bbox2roi([gt_bboxes])
            mask_feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], mask_rois)

        relations = assign_pairs(sbj_bboxes, sbj_labels, sbj_idxs, obj_bboxes,
                                 obj_labels, obj_idxs, gt_bboxes, gt_labels,
                                 gt_rel, gt_instid)

        try:
            # Column 2 of ``relations`` holds the relation label.
            targets = relations[:, 2].long()
        except Exception:
            print(debug_filename, relations, gt_bboxes, sbj_bboxes, obj_bboxes)
            raise

        # Keep every positive and a random subset of at most 3x as many
        # negatives.
        pos_indexs = (targets > 0).nonzero()
        neg_indexs = (targets == 0).nonzero()
        num_neg = len(neg_indexs)
        num_pos = len(pos_indexs)
        randperm = torch.randperm(num_neg)
        neg_indexs = neg_indexs[randperm[:min(num_neg, num_pos * 3)]]
        indexs = torch.cat((pos_indexs, neg_indexs), dim=0).squeeze(1)

        relations = relations[indexs, :]
        targets = relations[:, 2].long()

        # The last two columns index into the ground-truth instance list.
        sbj_index = relations[:, -2].long()
        obj_index = relations[:, -1].long()

        sbj_bboxes = gt_bboxes[sbj_index, :4]
        obj_bboxes = gt_bboxes[obj_index, :4]
        union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
        union_rois = bbox2roi([union_bboxes])

        with torch.no_grad():
            union_feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], union_rois)

        if debug_image is not None:
            import os
            import matplotlib.pyplot as plt
            debug_dir = os.path.join('./tmp/{}'.format(debug_filename[:-4]))
            os.makedirs(debug_dir, exist_ok=True)

            for i in range(relations.shape[0]):
                sbj_label = relations[i, 0].to(dtype=torch.int)
                obj_label = relations[i, 1].to(dtype=torch.int)
                rel_label = relations[i, 2].to(dtype=torch.int)
                sbj_inst = relations[i, 3].to(dtype=torch.int)
                obj_inst = relations[i, 4].to(dtype=torch.int)

                sbj_text = PicDatasetV20.CATEGORIES[sbj_label - 1]['name']
                obj_text = PicDatasetV20.CATEGORIES[obj_label - 1]['name']
                rel_text = PicDatasetV20.REL_CATEGORIES[rel_label-1]['name'] \
                    if rel_label > 0 else 'None'

                # plot
                _, ax = plt.subplots(1, 1, figsize=(10, 10))
                ax.imshow(debug_image, cmap=plt.cm.gray)

                # (box, caption, colour, extra Text kwargs) for the subject,
                # object and union overlays.
                overlays = (
                    (gt_bboxes[sbj_inst, :4], sbj_text, 'r', {}),
                    (gt_bboxes[obj_inst, :4], obj_text, 'g', {}),
                    (union_bboxes[i, :4], rel_text, 'w',
                     {'bbox': dict(facecolor='k', alpha=0.5)}),
                )
                for box, caption, color, text_kwargs in overlays:
                    x1, y1, x2, y2 = box
                    ax.add_artist(
                        plt.Rectangle((x1, y1),
                                      x2 - x1 + 1,
                                      y2 - y1 + 1,
                                      fill=False,
                                      color=color))
                    ax.add_artist(
                        plt.Text(x1,
                                 y1,
                                 caption,
                                 size='x-large',
                                 color=color,
                                 **text_kwargs))

                title = '<{}, {}, {}>'.format(sbj_text, obj_text, rel_text)
                ax.set_title(title)
                ax.axis('off')

                savename = os.path.join(debug_dir, '{:05d}.png'.format(i))
                plt.savefig(savename, dpi=100)
                plt.close()

        sbj_feats = mask_feats[sbj_index, ...]
        obj_feats = mask_feats[obj_index, ...]

        num_relations = relations.shape[0]
        sbj_feats = sbj_feats.view(num_relations, -1)
        obj_feats = obj_feats.view(num_relations, -1)
        union_feats = union_feats.view(num_relations, -1)

        pred = self.reldn_binary_head(sbj_feats, obj_feats, union_feats)
        loss = self.reldn_binary_head.loss(pred, targets > 0)

        return loss
Example #3
0
    def _rel_forward_train(self,
                           x,
                           gt_bboxes,
                           gt_labels,
                           gt_rel,
                           gt_instid,
                           rel_train_cfg,
                           im_width=None,
                           im_height=None,
                           debug_image=None,
                           debug_filename=None):
        """Compute the relation-head training loss from ground-truth boxes.

        Pairs are built with ``sample_pairs`` and labelled with
        ``assign_pairs``. Box RoI features for subject, object and union
        boxes are extracted without gradients; the subject features are
        additionally passed through ``self.sa``. The flattened features are
        concatenated and scored by ``self.reldn_head``.

        :param x: feature maps; only the first
            ``len(bbox_roi_extractor.featmap_strides)`` entries are used
        :param gt_bboxes: ground-truth boxes, one row per instance
            (presumably n x 4 -- TODO confirm)
        :param gt_labels: ground-truth labels, one entry per instance
        :param gt_rel: ground-truth relations, forwarded to ``assign_pairs``
            (presumably k x 3 -- TODO confirm)
        :param gt_instid: instance ids, one entry per instance
        :param rel_train_cfg: not read by this method
        :param im_width: forwarded to ``get_spatial_feature``
        :param im_height: forwarded to ``get_spatial_feature``
        :param debug_image: not read by this method
        :param debug_filename: not read by this method
        :return: the value returned by ``self.reldn_head.loss``
        """
        assert gt_bboxes.shape[0] == gt_labels.shape[0] == gt_instid.shape[0]

        (sbj_bboxes, sbj_labels, sbj_idxs,
         obj_bboxes, obj_labels, obj_idxs) = sample_pairs(gt_bboxes,
                                                          gt_labels)

        # Per-instance box features; the extractor is kept out of the graph.
        roi_extractor = self.bbox_roi_extractor[-1]
        with torch.no_grad():
            inst_rois = bbox2roi([gt_bboxes])
            inst_feats = roi_extractor(
                x[:len(roi_extractor.featmap_strides)], inst_rois)

        # Label every candidate pair (the no-relationship class included).
        relations, relation_targets = assign_pairs(
            sbj_bboxes, sbj_labels, sbj_idxs,
            obj_bboxes, obj_labels, obj_idxs,
            gt_bboxes, gt_labels, gt_rel, gt_instid)

        # Multi-hot targets: one row per relation, 11 hard-coded categories.
        num_relations = relations.shape[0]
        targets = relations.new_zeros((num_relations, 11))
        for row in range(num_relations):
            targets[row, relation_targets[row]] = 1

        # The last two columns of ``relations`` index the instance list.
        sbj_index = relations[:, -2].long()
        obj_index = relations[:, -1].long()

        sbj_bboxes = gt_bboxes[sbj_index, :]
        obj_bboxes = gt_bboxes[obj_index, :]
        union_rois = bbox2roi([bbox_union(sbj_bboxes, obj_bboxes)])

        # Only the subject features go through ``self.sa``.
        sbj_feats = self.sa(inst_feats[sbj_index, ...])
        obj_feats = inst_feats[obj_index, ...]

        with torch.no_grad():
            union_feats = roi_extractor(
                x[:len(roi_extractor.featmap_strides)], union_rois)

        # Flatten each feature map to one vector per relation.
        sbj_feats = sbj_feats.view(num_relations, -1)
        obj_feats = obj_feats.view(num_relations, -1)
        union_feats = union_feats.view(num_relations, -1)
        assert sbj_feats.shape == obj_feats.shape == union_feats.shape

        visual_features = torch.cat([sbj_feats, obj_feats, union_feats],
                                    dim=-1)
        spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                               im_width, im_height)

        prd_vis_scores, prd_bias_scores, prd_spt_scores = self.reldn_head(
            visual_features,
            sbj_labels=sbj_labels,
            obj_labels=obj_labels,
            spt_feat=spatial_features,
            sbj_feat=sbj_feats,
            obj_feat=obj_feats)

        label_weights = prd_vis_scores.new_ones(len(prd_vis_scores),
                                                dtype=torch.long)

        return self.reldn_head.loss(prd_vis_scores,
                                    targets,
                                    label_weights,
                                    prd_spt_score=prd_spt_scores)
Example #4
0
    def _rel_forward_test(self,
                          x,
                          det_bboxes,
                          det_labels,
                          det_masks,
                          scale_factor,
                          ori_shape,
                          im_width=None,
                          im_height=None,
                          semantic_feat=None):
        """Predict relations between detected instances.

        Every pair returned by ``sample_pairs`` is scored by
        ``self.reldn_head`` from subject, object and union RoI features plus
        spatial features.

        Args:
            x: Feature maps; only the first
                ``len(mask_roi_extractor.featmap_strides)`` entries are used.
            det_bboxes: Detected boxes, one row per detection.
            det_labels: Detected labels; ``det_labels + 1`` is passed to
                ``sample_pairs``.
            det_masks: Detected masks; a tensor is passed through
                ``sigmoid`` and converted to a numpy array.
            scale_factor: Returned unchanged in the result dict.
            ori_shape: Returned unchanged in the result dict.
            im_width: Forwarded to ``get_spatial_feature``.
            im_height: Forwarded to ``get_spatial_feature``.
            semantic_feat: Semantic feature map; read only when
                ``self.with_semantic`` is true and ``'mask'`` is in
                ``self.semantic_fusion``.

        Returns:
            ``None`` when ``sample_pairs`` yields no subject boxes, otherwise
            a dict with keys ``ret_bbox``, ``ret_mask``, ``ret_label``,
            ``ret_relation``, ``scale_factor`` and ``ori_shape``. Each
            ``ret_relation`` entry is
            ``[sbj_id, obj_id, rel_ids, rel_scores]`` where ``rel_ids`` and
            ``rel_scores`` are numpy arrays.
        """
        assert det_labels.shape[0] == det_bboxes.shape[0] == det_masks.shape[0]
        if isinstance(det_masks, torch.Tensor):
            det_masks = det_masks.sigmoid().cpu().numpy()

        rel_test_cfg = self.test_cfg.rel
        run_baseline = rel_test_cfg.run_baseline

        # build pairs
        (sbj_bboxes, sbj_labels, sbj_idxs,
         obj_bboxes, obj_labels, obj_idxs) = sample_pairs(det_bboxes,
                                                          det_labels + 1)

        if sbj_bboxes is None:
            return None

        num_pairs = len(sbj_idxs)

        mask_roi_extractor = self.mask_roi_extractor[-1]
        fuse_semantic = self.with_semantic and 'mask' in self.semantic_fusion

        def _extract_roi_feats(rois):
            # RoI features, optionally summed with pooled semantic features.
            feats = mask_roi_extractor(
                x[:len(mask_roi_extractor.featmap_strides)], rois)
            if fuse_semantic:
                semantic = self.semantic_roi_extractor([semantic_feat], rois)
                if semantic.shape[-2:] != feats.shape[-2:]:
                    semantic = F.adaptive_avg_pool2d(semantic,
                                                     feats.shape[-2:])
                feats += semantic
            return feats

        # extract roi features
        with torch.no_grad():
            mask_feats = _extract_roi_feats(bbox2roi([det_bboxes]))

        # get visual features
        sbj_feats = mask_feats[sbj_idxs, ...]
        obj_feats = mask_feats[obj_idxs, ...]

        # union
        union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
        union_rois = bbox2roi([union_bboxes])
        with torch.no_grad():
            union_feats = _extract_roi_feats(union_rois)

        sbj_feats = sbj_feats.view(num_pairs, -1)
        obj_feats = obj_feats.view(num_pairs, -1)
        union_feats = union_feats.view(num_pairs, -1)

        assert sbj_feats.shape == obj_feats.shape == union_feats.shape

        visual_features = torch.cat([sbj_feats, obj_feats, union_feats],
                                    dim=-1)

        spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                               im_width, im_height)

        prd_vis_scores, prd_bias_scores, prd_spt_scores \
            = self.reldn_head(visual_features,
                              sbj_labels=sbj_labels,
                              obj_labels=obj_labels,
                              spt_feat=spatial_features,
                              sbj_feat=sbj_feats,
                              obj_feat=obj_feats,
                              run_baseline=run_baseline)

        # detect relations
        ret_bboxes = det_bboxes.cpu().numpy()
        ret_labels = det_labels.cpu().numpy()
        ret_masks = det_masks
        ret_relations = []

        # Convert the pair indices once for both branches so that
        # ``ret_relation`` always carries numpy scalars, never tensors that
        # may still live on the device.
        sbj_idxs = sbj_idxs.cpu().numpy()
        obj_idxs = obj_idxs.cpu().numpy()

        # fill ret_relations
        if run_baseline:
            thresh = rel_test_cfg.thresh

            prd_bias_scores = prd_bias_scores.cpu().numpy()
            for i in range(prd_bias_scores.shape[0]):
                rel_scores = prd_bias_scores[i, :]
                # never predict __no_relation__ for frequency prior
                rel_scores[0] = 0
                # One entry per pair holding all above-threshold relation
                # ids; the array is empty when nothing passes.
                rel_ids = np.where(rel_scores > thresh)[0]
                ret_relations.append(
                    [sbj_idxs[i], obj_idxs[i], rel_ids, rel_scores[rel_ids]])
        else:
            prd_vis_scores = prd_vis_scores.cpu().numpy()
            prd_bias_scores = prd_bias_scores.cpu().numpy()
            prd_bias_scores[:, 0] = 0
            prd_spt_scores = prd_spt_scores.cpu().numpy()

            prd_score = (prd_vis_scores + prd_spt_scores) * prd_bias_scores
            for i in range(prd_score.shape[0]):
                rel_scores = prd_score[i, :]
                # Five highest-scoring relation ids, in ascending score order.
                rel_ids = rel_scores.argsort()[-5:]
                ret_relations.append(
                    [sbj_idxs[i], obj_idxs[i], rel_ids, rel_scores[rel_ids]])

        ret = {
            'ret_bbox': ret_bboxes,
            'ret_mask': ret_masks,
            'ret_label': ret_labels,
            'ret_relation': ret_relations,
            'scale_factor': scale_factor,
            'ori_shape': ori_shape
        }

        return ret
Example #5
0
    def _rel_forward_test(self,
                          x,
                          det_bboxes,
                          det_labels,
                          scale_factor,
                          ori_shape,
                          filename=None,
                          im_width=None,
                          im_height=None):
        """Predict relations between detections and format them as a dict.

        Pairs are built by ``sample_pairs`` (overlap filtering enabled),
        scored by ``self.reldn_head`` and written out as entity predictions
        plus relation predictions that reference them by list position.

        Args:
            x: Feature maps; only the first
                ``len(bbox_roi_extractor.featmap_strides)`` entries are used.
            det_bboxes: Detected boxes; the first four columns are written
                out as the box and the last column as the score.
            det_labels: Detected labels; ``det_labels + 1`` is used as the
                category id.
            scale_factor: Not read by this method.
            ori_shape: Not read by this method.
            filename: Not read by this method.
            im_width: Forwarded to ``get_spatial_feature``.
            im_height: Forwarded to ``get_spatial_feature``.

        Returns:
            A dict with two lists. ``predictions`` holds one
            ``{'bbox', 'category_id', 'score'}`` dict per entity that takes
            part in at least one pair. ``hoi_prediction`` holds
            ``{'subject_id', 'object_id', 'category_id', 'score'}`` dicts
            whose ids are positions in ``predictions``. Both lists are empty
            when no pair is available.
        """
        assert det_labels.shape[0] == det_bboxes.shape[0]

        rel_test_cfg = self.test_cfg.rel
        run_baseline = rel_test_cfg.run_baseline

        # build pairs
        (sbj_bboxes, sbj_labels, sbj_idxs,
         obj_bboxes, obj_labels, obj_idxs) = sample_pairs(
             det_bboxes, det_labels + 1,
             overlap=True, overlap_th=0.4, test=True)

        ret = {'predictions': [], 'hoi_prediction': []}
        if sbj_bboxes is None or sbj_bboxes.shape[0] == 0:
            return ret

        num_pairs = len(sbj_idxs)

        # extract roi features
        bbox_roi_extractor = self.bbox_roi_extractor[-1]
        with torch.no_grad():
            bbox_rois = bbox2roi([det_bboxes])
            bbox_feats = bbox_roi_extractor(
                x[:len(bbox_roi_extractor.featmap_strides)], bbox_rois)

        sbj_feats = bbox_feats[sbj_idxs, ...]
        obj_feats = bbox_feats[obj_idxs, ...]
        union_bboxes = bbox_union(sbj_bboxes, obj_bboxes)
        union_rois = bbox2roi([union_bboxes])
        with torch.no_grad():
            union_feats = bbox_roi_extractor(
                x[:len(bbox_roi_extractor.featmap_strides)], union_rois)

        sbj_feats = sbj_feats.view(num_pairs, -1)
        obj_feats = obj_feats.view(num_pairs, -1)
        union_feats = union_feats.view(num_pairs, -1)

        assert sbj_feats.shape == obj_feats.shape == union_feats.shape
        visual_features = torch.cat([sbj_feats, obj_feats, union_feats],
                                    dim=-1)
        spatial_features = get_spatial_feature(sbj_bboxes, obj_bboxes,
                                               im_width, im_height)
        prd_vis_scores, prd_bias_scores, prd_spt_scores \
            = self.reldn_head(visual_features,
                              sbj_labels=sbj_labels,
                              obj_labels=obj_labels,
                              spt_feat=spatial_features,
                              sbj_feat=sbj_feats,
                              obj_feat=obj_feats,
                              run_baseline=run_baseline)

        # detect relations
        ret_bboxes = det_bboxes.cpu().numpy()
        ret_labels = det_labels.cpu().numpy()

        thresh = rel_test_cfg.thresh

        sbj_idxs = sbj_idxs.cpu().numpy()
        obj_idxs = obj_idxs.cpu().numpy()

        # Emit each participating entity once (subjects first, then objects
        # not already seen) and remember its position in ``predictions``.
        entity_index_to_output_index = {}
        entity_order = list(np.unique(sbj_idxs)) + list(np.unique(obj_idxs))
        for entity_index in entity_order:
            if entity_index in entity_index_to_output_index:
                # Already emitted as a subject; do not duplicate the entry.
                continue
            entity_index_to_output_index[entity_index] = \
                len(ret['predictions'])
            ret['predictions'].append({
                'bbox': ret_bboxes[entity_index, :4].tolist(),
                'category_id': str(ret_labels[entity_index] + 1),
                'score': float(ret_bboxes[entity_index, -1])
            })

        def _append_relations(sbj_id, obj_id, rel_ids, rel_scores):
            # Record one hoi_prediction entry per selected relation id.
            sbj_output_index = int(entity_index_to_output_index[sbj_id])
            obj_output_index = int(entity_index_to_output_index[obj_id])
            for rel_id in rel_ids:
                ret['hoi_prediction'].append({
                    'subject_id': sbj_output_index,
                    'object_id': obj_output_index,
                    'category_id': int(rel_id),
                    'score': float(rel_scores[rel_id])
                })

        if run_baseline:
            prd_bias_scores = prd_bias_scores.cpu().numpy()
            for i in range(prd_bias_scores.shape[0]):
                sbj_id = sbj_idxs[i]
                obj_id = obj_idxs[i]
                rel_scores = prd_bias_scores[i, :]
                rel_scores[0] = 0

                # Hard-coded category rules: some object categories force a
                # single relation id regardless of score. The original
                # comments label 6/7/8 as "horse" and relation 6 as "ride"
                # -- TODO confirm against the dataset category list.
                obj_cat = ret_labels[obj_id] + 1
                if obj_cat == 8 or obj_cat == 7 or obj_cat == 6:
                    rel_ids = [6]
                elif obj_cat == 9:
                    rel_ids = [8]
                else:
                    rel_ids = np.where(rel_scores > thresh)[0]

                _append_relations(sbj_id, obj_id, rel_ids, rel_scores)
        else:
            prd_vis_scores = prd_vis_scores.cpu().numpy()
            prd_bias_scores = prd_bias_scores.cpu().numpy()
            prd_bias_scores[:, 0] = 0
            prd_spt_scores = prd_spt_scores.cpu().numpy()

            prd_score = (prd_vis_scores + prd_spt_scores) * prd_bias_scores
            for i in range(prd_score.shape[0]):
                sbj_id = sbj_idxs[i]
                obj_id = obj_idxs[i]
                rel_scores = prd_score[i, :]

                # Read scores from the numpy copy instead of the torch
                # tensor, avoiding one tensor-to-host read per pair.
                sbj_score = float(ret_bboxes[sbj_id, -1])
                obj_score = float(ret_bboxes[obj_id, -1])

                # Skip pairs involving a low-confidence detection.
                if sbj_score < 0.4 or obj_score < 0.4:
                    continue

                rel_ids = np.where(rel_scores > thresh)[0]
                _append_relations(sbj_id, obj_id, rel_ids, rel_scores)

        return ret