Example #1
def ohw_mask2boxlist(ohw_mask):
    """
    binary masks of all object, 1 image, to boxlist / rois
    arguments:
        ohw_mask: tensor, OHW, Nobjs binary mask 
    return:
        list_ref_box:
        template_valid: tensor, O 
    """
    O,H,W = ohw_mask.shape
    boxes_xyxy_o = []
    # an object's template is valid only if its mask has a non-zero sum (at least one foreground pixel)
    template_valid = (ohw_mask.sum(2).sum(1) > 0).long()
    for o in range(O):
        bin_mask = ohw_mask[o]
        boxes_xyxy = binmask_to_bbox_xyxy_pt(bin_mask)
        if len(boxes_xyxy) == 0:
            boxes_xyxy = [0,0,W-1,H-1]
        #check_boxes_xyxy = binmask_to_bbox_xyxy(bin_mask.cpu().numpy())
        #assert(boxes_xyxy == check_boxes_xyxy), '{} {}'.format(boxes_xyxy, check_boxes_xyxy)
        boxes_xyxy_o.append(boxes_xyxy)

    list_ref_box = BoxList(boxes_xyxy_o, (W,H), mode='xyxy')
    list_ref_box.add_field('mask', ohw_mask)
    list_ref_box = list_ref_box.to(ohw_mask.device)
    scores = ohw_mask.new_zeros((O)) + 1
    list_ref_box.add_field('scores', scores)
    template_valid.requires_grad = False
    #if template_valid.sum().item() > 0:
    #    template_valid[0,0] = 1 # at least one template 

    return list_ref_box, template_valid
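A minimal usage sketch (illustrative only; it assumes BoxList and binmask_to_bbox_xyxy_pt from the same maskrcnn-benchmark-style codebase are importable):

import torch

# one image with 3 object slots; only the first slot has a non-empty mask
ohw_mask = torch.zeros((3, 64, 64))
ohw_mask[0, 10:20, 15:30] = 1

ref_boxes, template_valid = ohw_mask2boxlist(ohw_mask)
print(ref_boxes)        # BoxList with 3 xyxy boxes plus 'mask' and 'scores' fields
print(template_valid)   # tensor([1, 0, 0]); empty masks are marked invalid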
Example #2
    def get_default_boxlist(boxlist: BoxList, bboxes, ids=None, labels=None):
        """
        Construct a boxlist with bbox as bboxes,
        all other fields to be default
        id -1, label -1
        """
        device = boxlist.bbox.device
        num_boxes = bboxes.shape[0]
        if ids is None:
            ids = torch.full((num_boxes,), -1.)
        if labels is None:
            labels = torch.full((num_boxes,), -1.)

        default_boxlist = BoxList(bboxes, image_size=boxlist.size, mode='xyxy')
        default_boxlist.add_field('labels', labels)
        default_boxlist.add_field('ids', ids)

        return default_boxlist.to(device)
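A hedged usage sketch (it assumes the method is exposed as a staticmethod, which the missing self argument suggests): build a default BoxList from raw xyxy boxes, reusing the image size and device of an existing BoxList.

import torch

existing = BoxList(torch.tensor([[0., 0., 50., 50.]]), image_size=(640, 480), mode='xyxy')
new_boxes = torch.tensor([[10., 10., 40., 60.],
                          [5., 20., 30., 45.]])

default_bl = get_default_boxlist(existing, new_boxes)
print(default_bl.get_field('ids'))     # tensor([-1., -1.])
print(default_bl.get_field('labels'))  # tensor([-1., -1.])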
Example #3
def eval_relation(
        dataset: VGDataset,
        predictions: [RelationTriplet],  # list of RelationTriplet
        output_folder):
    logger = logging.getLogger(__name__)
    rel_total_cnt = 0

    relation_hit_cnt = torch.zeros((2), dtype=torch.int32)  # counters for Recall@50 and Recall@100
    phrase_hit_num = torch.zeros((2), dtype=torch.int32)
    rel_loc_hit_cnt = torch.zeros((2), dtype=torch.int32)
    rel_inst_hit_cnt = torch.zeros((2), dtype=torch.int32)
    instance_det_hit_num = torch.zeros((2), dtype=torch.int32)

    eval_topks = cfg.MODEL.RELATION.TOPK_TRIPLETS

    cuda_dev = torch.zeros((1, 1)).cuda().device
    logger.info("start relationship evaluations. ")
    logger.info("relation static range %s" % str(eval_topks))
    true_det_rel = []

    det_total = 0

    relation_eval_res = {}
    for indx, rel_pred in tqdm(enumerate(predictions)):
        # rel_pred is a RelationTriplet obj
        # ipdb.set_trace()

        original_id = dataset.id_to_img_map[indx]
        img_info = dataset.get_img_info(indx)
        image_width = img_info["width"]
        image_height = img_info["height"]
        rel_pred.instance = rel_pred.instance.resize(
            (image_width, image_height))
        # get the boxes
        ann_ids = dataset.coco.getAnnIds(imgIds=original_id)
        anno = dataset.coco.loadAnns(ann_ids)
        gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0]
        det_total += len(gt_boxes)

        labels = [obj["category_id"] for obj in anno]
        # get gt boxes
        gt_boxes = torch.as_tensor(gt_boxes).reshape(
            -1, 4)  # guard against no boxes
        gt_boxes = BoxList(gt_boxes, (image_width, image_height),
                           mode="xywh").convert("xyxy")
        gt_boxes.add_field("labels", torch.LongTensor(labels))
        gt_boxes = gt_boxes.to(cuda_dev)
        rel_pred = rel_pred.to(cuda_dev)

        # get gt relations
        gt_relations = torch.as_tensor(dataset.relationships[original_id])
        gt_relations = gt_relations.to(cuda_dev)
        rel_total_cnt += gt_relations.shape[0]

        for i, topk in enumerate(eval_topks):
            selected_rel_pred = rel_pred[:topk]
            # IoU between the predicted instance boxes and the GT boxes
            instance_hit_iou = boxlist_iou(selected_rel_pred.instance,
                                           gt_boxes)
            if len(instance_hit_iou) == 0:
                continue
            max_iou_val, inst_loc_hit_idx = torch.max(instance_hit_iou, dim=1)

            # box pair location hit
            inst_det_hit_idx = inst_loc_hit_idx.clone().detach()
            neg_loc_hit_idx = (max_iou_val < 0.5)
            # mark detections whose best IoU with any GT box is below 0.5 as not hit (-1)
            inst_loc_hit_idx[neg_loc_hit_idx] = -1

            # box pair and category hit
            neg_det_hit_idx = neg_loc_hit_idx | \
                              (selected_rel_pred.instance.get_field("labels") !=
                               gt_boxes.get_field("labels")[inst_det_hit_idx])

            # additionally mark detections with a wrong category label as not hit (-1)
            inst_det_hit_idx[neg_det_hit_idx] = -1
            instance_det_hit_num[i] += len(
                torch.unique(inst_det_hit_idx[inst_det_hit_idx != -1]))

            # check the hit of each triplets in gt rel set
            rel_pair_mat = -torch.ones((selected_rel_pred.pair_mat.shape),
                                       dtype=torch.int64,
                                       device=cuda_dev)
            # instances box location hit res
            rel_loc_pair_mat = -torch.ones((selected_rel_pred.pair_mat.shape),
                                           dtype=torch.int64,
                                           device=cuda_dev)
            # instances box location and category hit
            rel_det_pair_mat = -torch.ones((selected_rel_pred.pair_mat.shape),
                                           dtype=torch.int64,
                                           device=cuda_dev)
            hit_rel_idx_collect = []
            for idx, gt_rel in enumerate(gt_relations):
                # write result into the pair mat
                # ipdb.set_trace()
                rel_pair_mat[:, 0] = inst_det_hit_idx[
                    selected_rel_pred.pair_mat[:, 0]]
                rel_pair_mat[:, 1] = inst_det_hit_idx[
                    selected_rel_pred.pair_mat[:, 1]]

                rel_loc_pair_mat[:, 0] = inst_loc_hit_idx[
                    selected_rel_pred.pair_mat[:, 0]]
                rel_loc_pair_mat[:, 1] = inst_loc_hit_idx[
                    selected_rel_pred.pair_mat[:, 1]]

                rel_det_pair_mat[:, 0] = inst_det_hit_idx[
                    selected_rel_pred.pair_mat[:, 0]]
                rel_det_pair_mat[:, 1] = inst_det_hit_idx[
                    selected_rel_pred.pair_mat[:, 1]]

                rel_hit_res = rel_pair_mat.eq(gt_rel[:2])
                rel_hit_idx = torch.nonzero((rel_hit_res.sum(dim=1) >= 2) & (
                    selected_rel_pred.phrase_l == gt_rel[-1]))

                rel_pair_loc_res = rel_loc_pair_mat.eq(gt_rel[:2])
                rel_loc_hit_idx = torch.nonzero(
                    (rel_pair_loc_res.sum(dim=1) >= 2))

                rel_inst_hit_res = rel_det_pair_mat.eq(gt_rel[:2])
                rel_inst_hit_idx = torch.nonzero(
                    (rel_inst_hit_res.sum(dim=1) >= 2))

                phrase_hit_idx = torch.nonzero(
                    selected_rel_pred.phrase_l == gt_rel[-1])

                if len(rel_hit_idx) >= 1:
                    relation_hit_cnt[i] += 1
                if len(rel_loc_hit_idx) >= 1:
                    rel_loc_hit_cnt[i] += 1
                if len(rel_inst_hit_idx) >= 1:
                    rel_inst_hit_cnt[i] += 1
                if len(phrase_hit_idx) >= 1:
                    phrase_hit_num[i] += 1

            #     hit_rel_idx_collect.append(rel_hit_idx)
            # hit_rel_pair_id = torch.cat(hit_rel_idx_collect).cpu()
            # rel_pred_save = rel_pred.to(hit_rel_pair_id.device)
            # true_det_rel.append((rel_pred_save, hit_rel_pair_id))

    # summarize result
    all_text_res = ''
    for i, topk in enumerate(eval_topks):
        relation_eval_res['relation Recall@%d' % topk] = {
            'relation': relation_hit_cnt[i].item() / rel_total_cnt,
            "phrase_cls": phrase_hit_num[i].item() / rel_total_cnt,
            "inst_pair_loc": rel_loc_hit_cnt[i].item() / rel_total_cnt,
            "inst_pair_cls": rel_inst_hit_cnt[i].item() / rel_total_cnt,
            "det": instance_det_hit_num[i].item() / det_total
        }

        txt_res = 'Relation detection Recall@%d \n' % topk \
                  + "instances location pair: {inst_pair_loc}\n" \
                    "instances detection pair: {inst_pair_cls} \n" \
                    "phrase cls: {phrase_cls} \n" \
                    "relation: {relation}\n" \
                    "detection: {det}\n".format(**relation_eval_res['relation Recall@%d'
                                                                    % topk])

        logger.info(txt_res)
        all_text_res += txt_res
    if output_folder:
        import json
        # torch.save(true_det_rel, os.path.join(output_folder, "relation_det_results.pth"))
        with open(os.path.join(output_folder, 'rel_eval_res.txt'), 'w') as f:
            f.write(json.dumps(relation_eval_res, indent=3))

    # todo visualization

    return relation_eval_res

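A minimal sketch of the hit-index bookkeeping used above (values are illustrative): each detected instance keeps the index of its best-matching GT box only when the IoU is at least 0.5; otherwise the index is set to -1 so it can never match a GT relation pair.

import torch

max_iou_val = torch.tensor([0.8, 0.3, 0.6])   # best IoU per detected instance
inst_loc_hit_idx = torch.tensor([2, 0, 1])    # index of the best-matching GT box
inst_loc_hit_idx[max_iou_val < 0.5] = -1      # low-IoU detections cannot hit any GT pair
print(inst_loc_hit_idx)                       # tensor([ 2, -1,  1])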
Example #4
    def forward(self, images, targets=None, auxiliary_task=False):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)
            auxiliary_task (bool): whether the auxiliary task is enabled during training

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] containing additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """

        if self.training and targets is None and not auxiliary_task:
            raise ValueError("In training mode, targets should be passed")

        if not self.training and auxiliary_task:
            raise ValueError("Cannot enable auxiliary task at test time")
        images = to_image_list(images)
        features = self.backbone(images.tensors)

        if auxiliary_task and self.cfg.MODEL.SELF_SUPERVISOR.TYPE == "rotation":
            # self._log_image_tensorboard(images, targets, 5)
            straight_features = features
            rotated_img_features = {0: straight_features}

            rotated_images = {0: images}

            for rot_i in range(1, 4):
                rot_images = []
                for img, img_size in zip(images.tensors, images.image_sizes):
                    rot_image, rot_index = SelfSup_Scrambler.rotate_single(
                        img, rot_i)
                    rot_images.append(rot_image)

                # need to move to gpu?
                stacked_tensor = torch.stack(rot_images)
                r_features = self.backbone(stacked_tensor)
                rotated_img_features[rot_i] = r_features
                rotated_images[rot_i] = to_image_list(rot_images)

        if targets is not None or not self.training:

            proposals, proposal_losses = self.rpn(images, features, targets)
            if self.roi_heads:
                x, result, detector_losses = self.roi_heads(
                    features, proposals, targets)
            else:
                # RPN-only models don't have roi_heads
                x = features
                result = proposals
                detector_losses = {}

        losses = {}

        pseudo_targets = None

        if auxiliary_task and self.cfg.MODEL.SELF_SUPERVISOR.TYPE == "rotation":
            # we use the *result* (a boxlist) as a list of boxes to perform the self supervised task
            if self.cfg.MODEL.SELF_SUPERVISOR.REGIONS == "detections" or (
                (self.cfg.MODEL.SELF_SUPERVISOR.REGIONS == "targets")
                    and targets is None):

                test_result = self.obtain_pseudo_labels(images, features)

            elif self.cfg.MODEL.SELF_SUPERVISOR.REGIONS == "targets":

                test_result = targets

            elif self.cfg.MODEL.SELF_SUPERVISOR.REGIONS == "images":

                image_sizes = images.image_sizes
                test_result = []

                for height, width in image_sizes:
                    xmin = 0
                    ymin = 0
                    xmax = width
                    ymax = height
                    bbox = torch.tensor([[xmin, ymin, xmax, ymax]],
                                        dtype=torch.float)
                    boxlist = BoxList(bbox, (width, height))

                    boxlist = boxlist.to(images.tensors.device)
                    test_result.append(boxlist)

            elif self.cfg.MODEL.SELF_SUPERVISOR.REGIONS == "crop":

                image_sizes = images.image_sizes

                test_result = []

                for height, width in image_sizes:
                    xmin, ymin, xmax, ymax = self.random_crop_image(
                        width, height)

                    bbox = torch.tensor([[xmin, ymin, xmax, ymax]],
                                        dtype=torch.float)
                    boxlist = BoxList(bbox, (width, height))

                    boxlist = boxlist.to(images.tensors.device)

                    test_result.append(boxlist)

            rotated_regions = {0: test_result}
            for rot_i in range(1, 4):
                r_result = [res[::] for res in test_result]

                for idx, box_list in enumerate(r_result):
                    rotated_boxes = box_list.transpose(rot_i + 1)

                    r_result[idx] = rotated_boxes

                rotated_regions[rot_i] = r_result

            # log images
            #for rot_i in range(0, 4):
            #    self._log_image_tensorboard(rotated_images[rot_i], rotated_regions[rot_i], rot_i)

            pooling_res = []
            rot_target_batch = []
            for idx_in_batch in range(len(test_result)):
                mul = 1
                rot_target = torch.ones((len(test_result[idx_in_batch]) * mul),
                                        dtype=torch.long)
                for r in range(len(test_result[idx_in_batch])):
                    rot = random.randint(0, 3)
                    features_r = rotated_img_features[rot]
                    regions_r = rotated_regions[rot][idx_in_batch][[r]]
                    l_regions_r = [regions_r]
                    pooled_features = self.region_feature_extractor(
                        features_r, l_regions_r)
                    pooled_features = self.ss_adaptive_pooling(pooled_features)
                    pooled_features = pooled_features.view(
                        pooled_features.size(0), -1)
                    class_preds = self.ss_classifier(
                        self.ss_dropout(pooled_features))
                    pooling_res.append(class_preds)
                    rot_target[r] = rot
                rot_target_batch.append(rot_target)

            if len(pooling_res) > 0:
                pooling_res = torch.stack(pooling_res).squeeze(dim=1)
                rot_target_batch = torch.cat(rot_target_batch).to(
                    pooling_res.device)
                aux_loss = self.ss_criterion(pooling_res, rot_target_batch)
                aux_loss = aux_loss.mean()
                # add to dictionary of losses
                losses["aux_loss"] = aux_loss

        if self.training:
            if targets is not None:
                losses.update(detector_losses)
                losses.update(proposal_losses)
            self.global_step += 1
            return losses

        return result
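The rotation branch relies on SelfSup_Scrambler.rotate_single, which is not shown here; a plausible stand-in (an assumption, not the project's actual implementation) rotates a CHW image tensor by rot_i * 90 degrees with torch.rot90:

import torch

def rotate_single(img, rot_i):
    # img: CHW tensor; rot_i in {0, 1, 2, 3} selects 0/90/180/270 degrees
    return torch.rot90(img, k=rot_i, dims=(1, 2)), rot_i

x = torch.randn(3, 32, 48)
rot_img, rot_idx = rotate_single(x, 1)
print(rot_img.shape, rot_idx)   # torch.Size([3, 48, 32]) 1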
Example #5
    def forward(self,
                anchors,
                objectness,
                box_regression,
                targets=None,
                centerness=None,
                rpn_center_box_regression=None,
                centerness_pack=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        num_levels = len(objectness)
        anchors = list(zip(*anchors))
        for a, o, b in zip(anchors, objectness, box_regression):
            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        if self.pred_targets:
            pred_targets = []
            if True:  # centerness + box-regression path; the else branch below is a disabled centerness-only alternative
                for img_centerness, center_box_reg in zip(
                        centerness, rpn_center_box_regression):
                    # gt_centerness, gt_bbox, anchor_bbox = center_target
                    # print(rpn_center_box_regression, anchor_bbox)
                    # gt_mask = gt_centerness.detach().cpu().numpy() > 0.0
                    img_centerness = img_centerness[0, :, :]

                    center_box_reg = center_box_reg[:, :, :].permute(1, 2, 0)

                    anchor_bbox = np.zeros(shape=(center_box_reg.shape[0],
                                                  center_box_reg.shape[1], 4))
                    for xx in range(anchor_bbox.shape[1]):
                        for yy in range(anchor_bbox.shape[0]):
                            anchor_bbox[yy, xx, :] = [
                                max(0.0, xx * 4 - 16),
                                max(0.0, yy * 4 - 16),
                                min(xx * 4 + 16, boxlists[0].size[0]),
                                min(yy * 4 + 16, boxlists[0].size[1])
                            ]
                    anchor_bbox = torch.as_tensor(anchor_bbox,
                                                  device=center_box_reg.device)

                    # print(center_box_reg.shape, anchor_bbox.shape)
                    boxes = self.box_coder.decode(
                        center_box_reg.reshape(-1, 4), anchor_bbox.view(-1, 4))

                    pred_target = None
                    pred_score = torch.sigmoid(
                        img_centerness.detach()).cpu().numpy()
                    pred_mask = pred_score > 0.95
                    # print(gt_mask.shape, pred_mask.shape)
                    imllabel, numlabel = scipy.ndimage.label(pred_mask)
                    if numlabel > 0:
                        valid = np.zeros(shape=(numlabel, ), dtype=bool)
                        box_inds = []
                        for ano in range(1, numlabel + 1):
                            mask = imllabel == ano
                            valid[ano - 1] = True  #  gt_mask[mask].sum() == 0
                            box_inds.append(np.argmax(pred_score * mask))
                        if np.any(valid):
                            boxes = boxes[box_inds, :]
                            # print(box_inds, boxes, anchor_bbox.view(-1, 4)[box_inds, :], gt_bbox.view(-1, 4)[box_inds, :])
                            pred_target = BoxList(torch.as_tensor(boxes),
                                                  boxlists[0].size,
                                                  mode="xyxy")
                            pred_target.clip_to_image()
                            pred_target = pred_target.to(img_centerness.device)
                            # print(img_centerness.device, pred_target.bbox.device)
                    pred_targets.append(pred_target)
            else:
                for img_centerness in centerness:
                    pred_target = None
                    pred_mask = torch.sigmoid(
                        img_centerness[0, :, :].detach()).cpu().numpy() > 0.95
                    # print(gt_mask.shape, pred_mask.shape)
                    imllabel, numlabel = scipy.ndimage.label(pred_mask)
                    if numlabel > 0:
                        masks = np.zeros(shape=(pred_mask.shape[0],
                                                pred_mask.shape[1], numlabel),
                                         dtype=np.uint8)
                        valid = np.zeros(shape=(numlabel, ), dtype=bool)
                        for ano in range(1, numlabel + 1):
                            mask = imllabel == ano
                            valid[ano - 1] = True
                            masks[:, :, ano - 1] = mask
                        if np.any(valid):
                            masks = masks[:, :, valid]
                            boxes = extract_bboxes(masks)
                            pred_target = BoxList(torch.as_tensor(boxes),
                                                  boxlists[0].size,
                                                  mode="xyxy")
                            pred_target.clip_to_image()
                            pred_target = pred_target.to(img_centerness.device)
                            # print(img_centerness.device, pred_target.bbox.device)
                    pred_targets.append(pred_target)

            if True:  # merge the predicted pseudo-targets into the proposals
                if not self.training:
                    print('add', [
                        len(pred_target)
                        for pred_target in pred_targets if pred_target
                    ], 'proposals')
                boxlists = self.add_pred_proposals(boxlists, pred_targets)
        else:
            pred_targets = None

        return boxlists, pred_targets
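A minimal sketch of the connected-component step used in the pred_targets branch: threshold the predicted centerness map, label the blobs with scipy.ndimage.label, and take the peak-score location inside each blob (the array below is a toy example).

import numpy as np
import scipy.ndimage

pred_score = np.array([[0.10, 0.97, 0.96, 0.10],
                       [0.10, 0.98, 0.10, 0.10],
                       [0.10, 0.10, 0.96, 0.97]])
pred_mask = pred_score > 0.95
imllabel, numlabel = scipy.ndimage.label(pred_mask)   # 2 blobs in this toy map
for ano in range(1, numlabel + 1):
    mask = imllabel == ano
    peak = np.unravel_index(np.argmax(pred_score * mask), pred_score.shape)
    print(ano, peak)   # blob id and the (row, col) of its centerness peak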