Пример #1
0
    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Run the box (Faster R-CNN) branch and the predicate branch of the head.

        In training mode the proposals are re-sampled by
        ``select_training_samples`` and only losses/accuracies are produced.
        In evaluation mode the post-processed detections of each image are
        paired (``remove_self_pairs`` filters the index pairs), each pair is
        scored, and the relationships that survive the filters are returned.

        Arguments:
            features (Dict[str, Tensor]): feature maps handed to
                ``box_roi_pool``.
            proposals (List[Tensor[N, 4]]): box proposals, one entry per image.
            image_shapes (List[Tuple[H, W]]): image sizes, one entry per image.
            targets (List[Dict], optional): ground truth. When given, every
                dict must hold floating point "boxes" and int64 "labels".
                Consumed by ``select_training_samples`` in training mode.

        Returns:
            Tuple ``(result, losses)``.

            * training: ``result`` is an empty list and ``losses`` holds
              "loss_classifier", "loss_box_reg", "loss_sbj", "loss_obj",
              "loss_rlp" plus the accuracies "acc_sbj", "acc_obj", "acc_rlp".
              NOTE(review): the accuracies are stored via ``.item()``, so they
              are Python numbers although the type comment says ``Tensor``.
            * evaluation: ``result`` holds one dict per image with the keys
              "sbj_boxes", "obj_boxes", "sbj_labels", "obj_labels" and
              "predicates"; ``losses`` is empty.
        """
        if targets is not None:
            # TODO: https://github.com/pytorch/pytorch/issues/26731
            floating_point_types = (torch.float, torch.double, torch.half)
            for t in targets:
                assert t[
                    "boxes"].dtype in floating_point_types, 'target boxes must of float type'
                assert t[
                    "labels"].dtype == torch.int64, 'target labels must of int64 type'

        if self.training:
            proposals, matched_idxs, labels, regression_targets, data_sbj, data_obj, data_rlp = self.select_training_samples(
                proposals, targets)

            # faster_rcnn branch
            box_features = self.box_roi_pool(features, proposals, image_shapes)
            box_features = self.box_head(box_features)
            class_logits, box_regression = self.box_predictor(box_features)

            # predicate branch: subject and object boxes share the box head,
            # the union (relationship) boxes go through their own head.
            sbj_feat = self.box_roi_pool(features, data_sbj["proposals"],
                                         image_shapes)
            sbj_feat = self.box_head(sbj_feat)
            obj_feat = self.box_roi_pool(features, data_obj["proposals"],
                                         image_shapes)
            obj_feat = self.box_head(obj_feat)

            rel_feat = self.box_roi_pool(features, data_rlp["proposals"],
                                         image_shapes)
            rel_feat = self.rlp_head(rel_feat)

            concat_feat = torch.cat((sbj_feat, rel_feat, obj_feat), dim=1)

            sbj_cls_scores, obj_cls_scores, rlp_cls_scores = \
                self.RelDN(concat_feat, sbj_feat, obj_feat)

            # No detections are reported while training.
            result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])

            assert labels is not None and regression_targets is not None

            loss_cls_sbj, accuracy_cls_sbj = reldn_losses(
                sbj_cls_scores, data_sbj["labels"])
            loss_cls_obj, accuracy_cls_obj = reldn_losses(
                obj_cls_scores, data_obj['labels'])
            loss_cls_rlp, accuracy_cls_rlp = reldn_losses(
                rlp_cls_scores, data_rlp['labels'])

            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = {
                "loss_classifier": loss_classifier,
                "loss_box_reg": loss_box_reg,
                "loss_sbj": loss_cls_sbj,
                "acc_sbj": accuracy_cls_sbj.item(),
                "loss_obj": loss_cls_obj,
                "acc_obj": accuracy_cls_obj.item(),
                "loss_rlp": loss_cls_rlp,
                "acc_rlp": accuracy_cls_rlp.item()
            }

        else:
            result = []
            # Bound up front so the return below is valid even when there is
            # no image to iterate over.
            losses = {}

            # faster_rcnn branch
            box_features = self.box_roi_pool(features, proposals, image_shapes)
            box_features = self.box_head(box_features)
            class_logits, box_regression = self.box_predictor(box_features)

            boxes, scores, labels = self.postprocess_detections(
                class_logits, box_regression, proposals, image_shapes)

            all_sbj_boxes = []
            all_obj_boxes = []
            all_rlp_boxes = []
            all_shapes = []
            for img_boxes in boxes:
                # Cartesian product of detection indices of this image.
                num_dets = img_boxes.shape[0]
                sbj_inds = np.repeat(np.arange(num_dets), num_dets)
                obj_inds = np.tile(np.arange(num_dets), num_dets)

                sbj_inds, obj_inds = self.remove_self_pairs(sbj_inds, obj_inds)

                sbj_boxes = img_boxes[sbj_inds]
                obj_boxes = img_boxes[obj_inds]
                rlp_boxes = box_utils.boxes_union(sbj_boxes, obj_boxes)

                all_sbj_boxes.append(sbj_boxes)
                all_obj_boxes.append(obj_boxes)
                all_rlp_boxes.append(rlp_boxes)
                # Number of pairs per image, used to split the batched scores.
                all_shapes.append(rlp_boxes.shape[0])

            # predicate branch
            sbj_feat = self.box_roi_pool(features, all_sbj_boxes, image_shapes)
            sbj_feat = self.box_head(sbj_feat)
            obj_feat = self.box_roi_pool(features, all_obj_boxes, image_shapes)
            obj_feat = self.box_head(obj_feat)
            rel_feat = self.box_roi_pool(features, all_rlp_boxes, image_shapes)
            rel_feat = self.rlp_head(rel_feat)
            concat_feat = torch.cat((sbj_feat, rel_feat, obj_feat), dim=1)
            sbj_cls_scores, obj_cls_scores, rlp_cls_scores = \
                self.RelDN(concat_feat, sbj_feat, obj_feat)

            sbj_cls_scores_list = sbj_cls_scores.split(all_shapes)
            obj_cls_scores_list = obj_cls_scores.split(all_shapes)
            rlp_cls_scores_list = rlp_cls_scores.split(all_shapes)

            per_image = zip(sbj_cls_scores_list, obj_cls_scores_list,
                            rlp_cls_scores_list, all_sbj_boxes, all_obj_boxes)
            for img_sbj_scores, img_obj_scores, img_rlp_scores, img_sbj_boxes, img_obj_boxes in per_image:
                _, sbj_indices = torch.max(img_sbj_scores, dim=1)
                _, obj_indices = torch.max(img_obj_scores, dim=1)
                rel_scores, rel_indices = torch.max(img_rlp_scores, dim=1)

                # filter "unknown" (predicate index 0)
                mask = rel_indices > 0
                rel_scores = rel_scores[mask]
                predicates = rel_indices[mask]
                subjects = sbj_indices[mask]
                objects = obj_indices[mask]
                sbj_boxes = img_sbj_boxes[mask]
                obj_boxes = img_obj_boxes[mask]

                # keep only relationships whose best predicate score exceeds 0.4
                score_mask = rel_scores > 0.4
                # One entry per image: the previous code re-bound ``result``
                # here, which dropped every image but the last of a batch.
                result.append({
                    "sbj_boxes": sbj_boxes[score_mask],
                    "obj_boxes": obj_boxes[score_mask],
                    'sbj_labels': subjects[score_mask],
                    'obj_labels': objects[score_mask],
                    'predicates': predicates[score_mask],
                })

        return result, losses
Пример #2
0
    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        # type: Optional[List[Dict[str, Tensor]]]
        targets):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """
        Build the training samples for the box branch and the predicate branch.

        NOTE(review): the return type comment above lists four elements, but
        the function returns seven values (see "Returns").

        Steps, in order:

        1. Ground-truth boxes are appended to the proposals
           (``add_gt_proposals``).
        2. Box branch: proposals are matched (``assign_to="all"``) and
           sub-sampled (``sample_for="all"``); regression targets are encoded
           with ``box_coder``.
        3. Subject and object branches: proposals are matched and sub-sampled
           again, then reduced by ``extract_positive_proposals``.
        4. Per image, every (subject, object) combination of those proposals
           is formed and its union box becomes a relationship proposal.
        5. Relationship proposals get predicate labels from
           ``assign_pred_to_rlp_proposals`` and are sub-sampled
           (``sample_for="rel"``); the same selection is applied to the paired
           subject/object proposals and labels.

        Arguments:
            proposals (List[Tensor]): box proposals, one entry per image.
            targets (List[Dict[str, Tensor]]): ground truth, must not be
                ``None``. The keys "boxes", "labels" and "preds" are read.

        Returns:
            Tuple ``(all_proposals, matched_idxs, labels, regression_targets,
            data_sbj, data_obj, data_rlp)``. The three ``data_*`` values are
            dicts with the keys 'proposals' and 'labels', each a per-image
            list.
        """
        self.check_targets(targets)
        assert targets is not None
        dtype = proposals[0].dtype
        device = proposals[0].device

        # NOTE(review): the original comment gave the shape as "list of
        # [Tensor of size 10,2,4]" — not verifiable from this function.
        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        # NOTE(review): the original comment gave the shape as "list of
        # [Tensor of size 10,2]" — not verifiable from this function.
        gt_labels = [t["labels"] for t in targets]
        gt_preds = [t["preds"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(
            proposals, gt_boxes, gt_labels, assign_to="all")
        # sample a fixed proportion of positive-negative proposals
        # NOTE(review): the original comment said "size 512" — confirm
        # against the sampler configuration.
        sampled_inds = self.subsample(labels, sample_for="all")
        # Shallow list copy: entries are replaced per image below, the
        # ``proposals`` list itself stays untouched for the later branches.
        all_proposals = proposals.copy()
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            all_proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            # Flatten the ground truth of this image to rows of 4 coordinates.
            gt_boxes_in_image = gt_boxes[img_id].view(-1, 4)
            if gt_boxes_in_image.numel() == 0:
                # No ground truth: use a single all-zero box so the indexing
                # below still works.
                gt_boxes_in_image = torch.zeros((1, 4),
                                                dtype=dtype,
                                                device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes,
                                                   all_proposals)

        # Subject branch: match and sample the (gt-augmented) proposals again.
        _, sbj_labels = self.assign_targets_to_proposals(proposals,
                                                         gt_boxes,
                                                         gt_labels,
                                                         assign_to="subject")
        # NOTE(review): the original comment said "64 --> 32 pos, 32 neg" —
        # confirm against the sampler configuration.
        sampled_inds = self.subsample(
            sbj_labels, sample_for="subject")
        sbj_proposals = proposals.copy()
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            sbj_proposals[img_id] = proposals[img_id][img_sampled_inds]
            sbj_labels[img_id] = sbj_labels[img_id][img_sampled_inds]
        pos_sbj_labels, pos_sbj_proposals = self.extract_positive_proposals(
            sbj_labels, sbj_proposals)

        # Object branch: same procedure as for subjects.
        # NOTE(review): ``assign_to="objects"`` (plural) here versus
        # ``sample_for="object"`` (singular) below — verify both helpers
        # accept these exact spellings.
        _, obj_labels = self.assign_targets_to_proposals(proposals,
                                                         gt_boxes,
                                                         gt_labels,
                                                         assign_to="objects")
        # NOTE(review): the original comment said "64 --> 32 pos, 32 neg" —
        # confirm against the sampler configuration.
        sampled_inds = self.subsample(
            obj_labels, sample_for="object")
        obj_proposals = proposals.copy()
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            obj_proposals[img_id] = proposals[img_id][img_sampled_inds]
            obj_labels[img_id] = obj_labels[img_id][img_sampled_inds]
        pos_obj_labels, pos_obj_proposals = self.extract_positive_proposals(
            obj_labels, obj_proposals)

        # prepare relation proposals
        rlp_proposals = []
        for img_id in range(num_images):
            sbj_shape = pos_sbj_labels[img_id].shape[0]
            obj_shape = pos_obj_labels[img_id].shape[0]
            # Index pairs for the full subject x object product: each subject
            # index is repeated obj_shape times, the object indices are tiled
            # sbj_shape times.
            sbj_inds = np.repeat(np.arange(sbj_shape), obj_shape)
            obj_inds = np.tile(np.arange(obj_shape), sbj_shape)

            # Expand labels and proposals so row i of the subject data pairs
            # with row i of the object data.
            pos_sbj_labels[img_id] = pos_sbj_labels[img_id][sbj_inds]
            pos_obj_labels[img_id] = pos_obj_labels[img_id][obj_inds]
            pos_sbj_proposals[img_id] = pos_sbj_proposals[img_id][sbj_inds]
            pos_obj_proposals[img_id] = pos_obj_proposals[img_id][obj_inds]

            # The relationship proposal is the union box of each pair.
            rlp_proposals.append(
                box_utils.boxes_union(pos_obj_proposals[img_id],
                                      pos_sbj_proposals[img_id]))

        # assign gt_predicate to relation proposals
        rlp_labels = self.assign_pred_to_rlp_proposals(pos_sbj_proposals,
                                                       pos_obj_proposals,
                                                       gt_boxes, gt_labels,
                                                       gt_preds)

        # NOTE(review): the original comment said "128 --> 64 pos, 64 neg" —
        # confirm against the sampler configuration.
        sampled_inds = self.subsample(rlp_labels, sample_for="rel")
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            pos_sbj_proposals[img_id] = pos_sbj_proposals[img_id][
                img_sampled_inds]
            pos_obj_proposals[img_id] = pos_obj_proposals[img_id][
                img_sampled_inds]
            rlp_proposals[img_id] = rlp_proposals[img_id][img_sampled_inds]
            # Subject/object labels are shifted down by one here
            # (presumably to drop the background class — TODO confirm).
            pos_sbj_labels[
                img_id] = pos_sbj_labels[img_id][img_sampled_inds] - 1
            pos_obj_labels[
                img_id] = pos_obj_labels[img_id][img_sampled_inds] - 1
            rlp_labels[img_id] = rlp_labels[img_id][img_sampled_inds]

        data_sbj = {'proposals': pos_sbj_proposals, 'labels': pos_sbj_labels}
        data_obj = {'proposals': pos_obj_proposals, 'labels': pos_obj_labels}
        data_rlp = {'proposals': rlp_proposals, 'labels': rlp_labels}

        return all_proposals, matched_idxs, labels, regression_targets, data_sbj, data_obj, data_rlp
Пример #3
0
def _select_topk_rel_dets(res, prd_k):
    """Rank the relationship candidates of one image and keep the best ones.

    For every candidate pair the ``prd_k`` highest predicate scores are
    multiplied with the subject and object scores; the resulting triplets are
    ordered with ``argsort_desc`` and cut at the module-level ``topk``.

    Args:
        res: per-image result dict. Reads 'prd_scores' (``None`` when the
            image has no detections), 'sbj_boxes', 'obj_boxes', 'sbj_labels',
            'obj_labels', 'sbj_scores' and 'obj_scores'.
        prd_k: number of top predicates considered per pair.

    Returns:
        Tuple ``(boxes_so, labels_spo, scores)``: subject and object boxes
        stacked column-wise, (subject, predicate, object) labels per row and
        the combined score of each kept triplet. All three have zero rows when
        ``res['prd_scores']`` is ``None``.
    """
    if res['prd_scores'] is None:
        # in oi_all_rel some images have no dets
        return (np.zeros((0, 8), dtype=np.float32),
                np.zeros((0, 3), dtype=np.int32),
                np.zeros(0, dtype=np.float32))

    det_boxes_sbj = res['sbj_boxes']  # (#num_rel, 4)
    det_boxes_obj = res['obj_boxes']  # (#num_rel, 4)
    det_labels_sbj = res['sbj_labels']  # (#num_rel,)
    det_labels_obj = res['obj_labels']  # (#num_rel,)
    det_scores_sbj = res['sbj_scores']  # (#num_rel,)
    det_scores_obj = res['obj_scores']  # (#num_rel,)
    # Column 0 is dropped before ranking, so predicate labels index the
    # remaining columns.
    det_scores_prd = res['prd_scores'][:, 1:]

    # Per pair: predicate labels and scores in descending score order.
    det_labels_prd = np.argsort(-det_scores_prd, axis=1)
    det_scores_prd = -np.sort(-det_scores_prd, axis=1)

    det_scores_so = det_scores_sbj * det_scores_obj
    det_scores_spo = det_scores_so[:, None] * det_scores_prd[:, :prd_k]
    det_scores_inds = argsort_desc(det_scores_spo)[:topk]
    pair_inds = det_scores_inds[:, 0]  # which subject/object pair
    rank_inds = det_scores_inds[:, 1]  # which of its top predicates

    det_scores_top = det_scores_spo[pair_inds, rank_inds]
    det_boxes_so_top = np.hstack(
        (det_boxes_sbj[pair_inds], det_boxes_obj[pair_inds]))
    det_labels_spo_top = np.vstack(
        (det_labels_sbj[pair_inds],
         det_labels_prd[pair_inds, rank_inds],
         det_labels_obj[pair_inds])).transpose()
    return det_boxes_so_top, det_labels_spo_top, det_scores_top


def _topk_det_record(res, det_boxes_so_top, det_labels_spo_top,
                     det_scores_top):
    """Pack the kept triplets of one image into the dict stored per image."""
    return dict(image=res['image'],
                det_boxes_s_top=det_boxes_so_top[:, :4],
                det_boxes_o_top=det_boxes_so_top[:, 4:],
                det_labels_s_top=det_labels_spo_top[:, 0],
                det_labels_p_top=det_labels_spo_top[:, 1],
                det_labels_o_top=det_labels_spo_top[:, 2],
                det_scores_top=det_scores_top)


def eval_rel_results(all_results, output_dir, do_val):
    """Evaluate relationship detections or dump the top-k detections.

    With ``do_val`` the recall at 20/50/100 is printed for every evaluation
    mode ('reldet', plus 'phrdet' unless the first test dataset name contains
    'vg') and every ``prd_k`` of the dataset-specific set. Without ``do_val``
    the top-k detections (``prd_k = 2``) of all images are pickled to
    ``<output_dir>/rel_detections_topk.pkl``.

    Images whose ``res['prd_scores']`` is ``None`` contribute empty detection
    arrays. Previously the validation path left the arrays used for matching
    unassigned for such images, which raised ``NameError`` on the first image
    or silently reused the detections of the previous image.

    Args:
        all_results: iterable of per-image result dicts (see
            ``_select_topk_rel_dets``); with ``do_val`` the keys
            'gt_sbj_boxes', 'gt_obj_boxes', 'gt_sbj_labels', 'gt_obj_labels'
            and 'gt_prd_labels' are read as well.
        output_dir: directory of the pickle file; only used without
            ``do_val``.
        do_val: whether ground truth is available and recall is computed.

    Returns:
        None.
    """
    if not do_val:
        prd_k = 2
        topk_dets = []
        for res in tqdm(all_results):
            det_boxes_so_top, det_labels_spo_top, det_scores_top = \
                _select_topk_rel_dets(res, prd_k)
            topk_dets.append(
                _topk_det_record(res, det_boxes_so_top, det_labels_spo_top,
                                 det_scores_top))
        print('Saving topk dets...')
        topk_dets_f = os.path.join(output_dir, 'rel_detections_topk.pkl')
        with open(topk_dets_f, 'wb') as f:
            pickle.dump(topk_dets, f, pickle.HIGHEST_PROTOCOL)
        logger.info('topk_dets size: {}'.format(len(topk_dets)))
        print('Done.')
        return

    dataset_name = cfg.TEST.DATASETS[0]
    is_vg = dataset_name.find('vg') >= 0
    if is_vg:
        prd_k_set = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20)
    elif dataset_name.find('vrd') >= 0:
        prd_k_set = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 70)
    else:
        prd_k_set = (1, 2, 3, 4, 5, 6, 7, 8, 9)
    eval_sets = (False, ) if is_vg else (False, True)

    for phrdet in eval_sets:
        eval_metric = 'phrdet' if phrdet else 'reldet'
        print('{}:'.format(eval_metric))

        for prd_k in prd_k_set:
            print('prd_k = {}:'.format(prd_k))

            recalls = {20: 0, 50: 0, 100: 0}
            all_gt_cnt = 0
            topk_dets = []
            for res in tqdm(all_results):
                det_boxes_so_top, det_labels_spo_top, det_scores_top = \
                    _select_topk_rel_dets(res, prd_k)
                record = _topk_det_record(res, det_boxes_so_top,
                                          det_labels_spo_top, det_scores_top)
                topk_dets.append(record)

                gt_boxes_sbj = res['gt_sbj_boxes']  # (#num_gt, 4)
                gt_boxes_obj = res['gt_obj_boxes']  # (#num_gt, 4)
                gt_labels_sbj = res['gt_sbj_labels']  # (#num_gt,)
                gt_labels_obj = res['gt_obj_labels']  # (#num_gt,)
                gt_labels_prd = res['gt_prd_labels']  # (#num_gt,)
                gt_boxes_so = np.hstack((gt_boxes_sbj, gt_boxes_obj))
                gt_labels_spo = np.vstack((gt_labels_sbj, gt_labels_prd,
                                           gt_labels_obj)).transpose()
                # Compute recall. It's most efficient to match once and then
                # do recall after.
                # det_boxes_so_top is (#num_rel, 8)
                # det_labels_spo_top is (#num_rel, 3)
                if phrdet:
                    # Phrase detection matches the union boxes of the pairs.
                    det_boxes_r_top = boxes_union(record['det_boxes_s_top'],
                                                  record['det_boxes_o_top'])
                    gt_boxes_r = boxes_union(gt_boxes_sbj, gt_boxes_obj)
                    pred_to_gt = _compute_pred_matches(gt_labels_spo,
                                                       det_labels_spo_top,
                                                       gt_boxes_r,
                                                       det_boxes_r_top,
                                                       phrdet=phrdet)
                else:
                    pred_to_gt = _compute_pred_matches(gt_labels_spo,
                                                       det_labels_spo_top,
                                                       gt_boxes_so,
                                                       det_boxes_so_top,
                                                       phrdet=phrdet)
                all_gt_cnt += gt_labels_spo.shape[0]
                for k in recalls:
                    if len(pred_to_gt):
                        # Distinct ground-truth entries hit by the first k
                        # predictions.
                        match = reduce(np.union1d, pred_to_gt[:k])
                    else:
                        match = []
                    recalls[k] += len(match)

                record.update(
                    dict(gt_boxes_sbj=gt_boxes_sbj,
                         gt_boxes_obj=gt_boxes_obj,
                         gt_labels_sbj=gt_labels_sbj,
                         gt_labels_obj=gt_labels_obj,
                         gt_labels_prd=gt_labels_prd))

            for k in recalls:
                # The epsilon keeps the division defined without ground truth.
                recalls[k] = float(
                    recalls[k]) / (float(all_gt_cnt) + 1e-12)
            print_stats(recalls)
def eval_rel_results(all_results, output_dir, do_val):
    """Rank relationship detections per image and optionally print recall.

    For every evaluation mode ('reldet', plus 'phrdet' unless the first test
    dataset name contains 'vg') and every ``prd_k`` in ``(1, 2)`` the top-k
    triplets of each image are selected. For images that have detections the
    selection is pickled to ``VRD-<image index>-<prd_k>.pkl`` in the current
    working directory. With ``do_val`` the recall at 20/50/100 is printed.

    Images whose ``res['prd_scores']`` is ``None`` contribute empty detection
    arrays. Previously the arrays used for matching were left unassigned for
    such images, which raised ``NameError`` on the first image or silently
    reused the detections of the previous image.

    NOTE(review): ``output_dir`` is not used by this function; the per-image
    pickles are written relative to the working directory.

    Args:
        all_results: iterable of per-image result dicts. Reads 'image',
            'prd_scores' (``None`` when there are no detections), 'sbj_boxes',
            'obj_boxes', 'sbj_labels', 'obj_labels', 'sbj_scores',
            'obj_scores' and, with ``do_val``, 'gt_sbj_boxes', 'gt_obj_boxes',
            'gt_sbj_labels', 'gt_obj_labels', 'gt_prd_labels'.
        output_dir: unused.
        do_val: whether ground truth is available and recall is computed.

    Returns:
        None.
    """
    if cfg.TEST.DATASETS[0].find('vg') >= 0:
        eval_sets = (False, )
    else:
        eval_sets = (False, True)

    # NOTE(review): a dataset-specific prd_k_set used to be computed here and
    # was then unconditionally overridden by (1, 2); only the effective value
    # is kept. Restore the per-dataset table if the override was a leftover.
    prd_k_set = (1, 2)
    for phrdet in eval_sets:
        eval_metric = 'phrdet' if phrdet else 'reldet'
        print('{}:'.format(eval_metric))

        for prd_k in prd_k_set:
            print('prd_k = {}:'.format(prd_k))

            recalls = {20: 0, 50: 0, 100: 0}
            all_gt_cnt = 0

            topk_dets = []
            for im_i, res in enumerate(tqdm(all_results)):
                has_dets = res['prd_scores'] is not None

                if not has_dets:
                    # in oi_all_rel some images have no dets
                    det_boxes_so_top = np.zeros((0, 8), dtype=np.float32)
                    det_labels_spo_top = np.zeros((0, 3), dtype=np.int32)
                    det_scores_top = np.zeros(0, dtype=np.float32)
                else:
                    det_boxes_sbj = res['sbj_boxes']  # (#num_rel, 4)
                    det_boxes_obj = res['obj_boxes']  # (#num_rel, 4)
                    det_labels_sbj = res['sbj_labels']  # (#num_rel,)
                    det_labels_obj = res['obj_labels']  # (#num_rel,)
                    det_scores_sbj = res['sbj_scores']  # (#num_rel,)
                    det_scores_obj = res['obj_scores']  # (#num_rel,)
                    # Column 0 is dropped before ranking, so predicate labels
                    # index the remaining columns.
                    det_scores_prd = res['prd_scores'][:, 1:]

                    # Per pair: predicate labels and scores, best first.
                    det_labels_prd = np.argsort(-det_scores_prd, axis=1)
                    det_scores_prd = -np.sort(-det_scores_prd, axis=1)

                    det_scores_so = det_scores_sbj * det_scores_obj
                    det_scores_spo = (det_scores_so[:, None] *
                                      det_scores_prd[:, :prd_k])
                    # ``topk`` is read from module scope.
                    det_scores_inds = argsort_desc(det_scores_spo)[:topk]
                    pair_inds = det_scores_inds[:, 0]
                    rank_inds = det_scores_inds[:, 1]

                    det_scores_top = det_scores_spo[pair_inds, rank_inds]
                    det_boxes_so_top = np.hstack(
                        (det_boxes_sbj[pair_inds], det_boxes_obj[pair_inds]))
                    det_labels_spo_top = np.vstack(
                        (det_labels_sbj[pair_inds],
                         det_labels_prd[pair_inds, rank_inds],
                         det_labels_obj[pair_inds])).transpose()

                det_boxes_s_top = det_boxes_so_top[:, :4]
                det_boxes_o_top = det_boxes_so_top[:, 4:]
                det_labels_s_top = det_labels_spo_top[:, 0]
                det_labels_p_top = det_labels_spo_top[:, 1]
                det_labels_o_top = det_labels_spo_top[:, 2]

                if has_dets:
                    # Per-image dump, only written for images with detections.
                    out_dict = {
                        'boxes_s_top': det_boxes_s_top,
                        'boxes_o_top': det_boxes_o_top,
                        'labels_s_top': det_labels_s_top,
                        'labels_p_top': det_labels_p_top,
                        'labels_o_top': det_labels_o_top,
                        'scores_top': det_scores_top,
                        'image': res['image'],
                    }
                    with open('VRD-{}-{}.pkl'.format(im_i, prd_k),
                              'wb') as fout:
                        pickle.dump(out_dict, fout, pickle.HIGHEST_PROTOCOL)

                topk_dets.append(
                    dict(image=res['image'],
                         det_boxes_s_top=det_boxes_s_top,
                         det_boxes_o_top=det_boxes_o_top,
                         det_labels_s_top=det_labels_s_top,
                         det_labels_p_top=det_labels_p_top,
                         det_labels_o_top=det_labels_o_top,
                         det_scores_top=det_scores_top))

                if do_val:
                    gt_boxes_sbj = res['gt_sbj_boxes']  # (#num_gt, 4)
                    gt_boxes_obj = res['gt_obj_boxes']  # (#num_gt, 4)
                    gt_labels_sbj = res['gt_sbj_labels']  # (#num_gt,)
                    gt_labels_obj = res['gt_obj_labels']  # (#num_gt,)
                    gt_labels_prd = res['gt_prd_labels']  # (#num_gt,)
                    gt_boxes_so = np.hstack((gt_boxes_sbj, gt_boxes_obj))
                    gt_labels_spo = np.vstack((gt_labels_sbj, gt_labels_prd,
                                               gt_labels_obj)).transpose()
                    # Compute recall. It's most efficient to match once and
                    # then do recall after.
                    # det_boxes_so_top is (#num_rel, 8)
                    # det_labels_spo_top is (#num_rel, 3)
                    if phrdet:
                        # Phrase detection matches the pairs' union boxes.
                        det_boxes_r_top = boxes_union(det_boxes_s_top,
                                                      det_boxes_o_top)
                        gt_boxes_r = boxes_union(gt_boxes_sbj, gt_boxes_obj)
                        pred_to_gt = _compute_pred_matches(gt_labels_spo,
                                                           det_labels_spo_top,
                                                           gt_boxes_r,
                                                           det_boxes_r_top,
                                                           phrdet=phrdet)
                    else:
                        pred_to_gt = _compute_pred_matches(gt_labels_spo,
                                                           det_labels_spo_top,
                                                           gt_boxes_so,
                                                           det_boxes_so_top,
                                                           phrdet=phrdet)
                    all_gt_cnt += gt_labels_spo.shape[0]
                    for k in recalls:
                        if len(pred_to_gt):
                            # Distinct ground-truth entries hit by the first
                            # k predictions.
                            match = reduce(np.union1d, pred_to_gt[:k])
                        else:
                            match = []
                        recalls[k] += len(match)

                    topk_dets[-1].update(
                        dict(gt_boxes_sbj=gt_boxes_sbj,
                             gt_boxes_obj=gt_boxes_obj,
                             gt_labels_sbj=gt_labels_sbj,
                             gt_labels_obj=gt_labels_obj,
                             gt_labels_prd=gt_labels_prd))

            if do_val:
                for k in recalls:
                    # The epsilon keeps the division defined without any
                    # ground truth.
                    recalls[k] = float(
                        recalls[k]) / (float(all_gt_cnt) + 1e-12)
                print_stats(recalls)
def prepare_mAP_dets(topk_dets, cls_num):
    """Regroup per-image top-k detections and ground truth by predicate class.

    Args:
        topk_dets: iterable of per-image dicts holding the 'det_*_top' arrays,
            the 'gt_*' arrays and the 'image' path.
        cls_num: number of predicate classes; classes are ``0 .. cls_num - 1``.

    Returns:
        Tuple ``(cls_image_ids, cls_dets, cls_gts, npos)`` of four lists
        indexed by predicate class:

        * ``cls_image_ids[c]``: image id of every detection of class ``c``.
        * ``cls_dets[c]``: dict with the concatenated 'confidence', 'BB_s',
          'BB_o', 'BB_r', 'LBL_s' and 'LBL_o' arrays of class ``c``.
        * ``cls_gts[c]``: dict keyed by image id with the ground truth of
          class ``c`` in that image and a per-entry 'det' flag list.
        * ``npos[c]``: total number of ground-truth entries of class ``c``.
    """
    cls_image_ids = [[] for _ in range(cls_num)]
    cls_dets = [
        dict(confidence=np.empty(0),
             BB_s=np.empty((0, 4)),
             BB_o=np.empty((0, 4)),
             BB_r=np.empty((0, 4)),
             LBL_s=np.empty(0),
             LBL_o=np.empty(0)) for _ in range(cls_num)
    ]
    cls_gts = [{} for _ in range(cls_num)]
    npos = [0] * cls_num

    for record in tqdm(topk_dets):
        # Image id = file name without directory and without anything after
        # the first dot.
        file_name = record['image'].rsplit('/', 1)[-1]
        image_id = file_name.partition('.')[0]

        sbj_boxes = record['det_boxes_s_top']
        obj_boxes = record['det_boxes_o_top']
        rel_boxes = boxes_union(sbj_boxes, obj_boxes)
        sbj_labels = record['det_labels_s_top']
        obj_labels = record['det_labels_o_top']
        prd_labels = record['det_labels_p_top']
        det_scores = record['det_scores_top']

        gt_boxes_sbj = record['gt_boxes_sbj']
        gt_boxes_obj = record['gt_boxes_obj']
        gt_boxes_rel = boxes_union(gt_boxes_sbj, gt_boxes_obj)
        gt_labels_sbj = record['gt_labels_sbj']
        gt_labels_prd = record['gt_labels_prd']
        gt_labels_obj = record['gt_labels_obj']

        for c in range(cls_num):
            # Detections of this image predicted as class c.
            det_rows = np.where(prd_labels == c)[0]
            num_det_rows = len(det_rows)
            if num_det_rows:
                additions = (
                    ('confidence', det_scores[det_rows]),
                    ('BB_s', sbj_boxes[det_rows]),
                    ('BB_o', obj_boxes[det_rows]),
                    ('BB_r', rel_boxes[det_rows]),
                    ('LBL_s', sbj_labels[det_rows]),
                    ('LBL_o', obj_labels[det_rows]),
                )
                bucket = cls_dets[c]
                for key, values in additions:
                    bucket[key] = np.concatenate((bucket[key], values), 0)
                cls_image_ids[c].extend([image_id] * num_det_rows)

            # Ground truth of this image labelled with class c; an entry is
            # stored for every class, even when it is empty.
            gt_rows = np.where(gt_labels_prd == c)[0]
            gt_count = len(gt_rows)
            npos[c] += gt_count
            cls_gts[c][image_id] = {
                'gt_boxes_sbj': gt_boxes_sbj[gt_rows],
                'gt_boxes_obj': gt_boxes_obj[gt_rows],
                'gt_boxes_rel': gt_boxes_rel[gt_rows],
                'gt_labels_sbj': gt_labels_sbj[gt_rows],
                'gt_labels_obj': gt_labels_obj[gt_rows],
                'gt_num': gt_count,
                'det': [False] * gt_count
            }

    return cls_image_ids, cls_dets, cls_gts, npos