Example #1
    def _match_to_lbl(self, proposals, bbx, cat, ids, msk, match):
        cls_lbl = []
        bbx_lbl = []
        msk_lbl = []
        for i, (proposals_i, bbx_i, cat_i, ids_i, msk_i,
                match_i) in enumerate(zip(proposals, bbx, cat, ids, msk,
                                          match)):
            if match_i is not None:
                pos = match_i >= 0

                # Classification labels: 0 = background / unmatched, "thing" classes start at 1
                cls_lbl_i = proposals_i.new_zeros(proposals_i.size(0),
                                                  dtype=torch.long)
                cls_lbl_i[pos] = cat_i[match_i[pos]] + 1 - self.num_stuff

                # Bounding box regression labels
                if pos.any().item():
                    bbx_lbl_i = calculate_shift(proposals_i[pos],
                                                bbx_i[match_i[pos]])
                    bbx_lbl_i *= bbx_lbl_i.new(self.bbx_reg_weights)

                    iis_lbl_i = ids_i[match_i[pos]]

                    # Compute instance segmentation masks
                    msk_i = roi_sampling(
                        msk_i.unsqueeze(0),
                        proposals_i[pos],
                        msk_i.new_zeros(pos.long().sum().item()),
                        self.lbl_roi_size,  # 28x28
                        interpolation="nearest")

                    # Calculate mask segmentation labels
                    msk_lbl_i = (msk_i == iis_lbl_i.view(
                        -1, 1, 1, 1)).any(dim=1).to(torch.long)
                    if not self.void_is_background:
                        msk_lbl_i[(msk_i == 0).all(dim=1)] = -1
                else:
                    bbx_lbl_i = None
                    msk_lbl_i = None

                cls_lbl.append(cls_lbl_i)
                bbx_lbl.append(bbx_lbl_i)
                msk_lbl.append(msk_lbl_i)
            else:
                cls_lbl.append(None)
                bbx_lbl.append(None)
                msk_lbl.append(None)

        return PackedSequence(cls_lbl), PackedSequence(
            bbx_lbl), PackedSequence(msk_lbl)
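Example #1 turns the proposal-to-ground-truth matches into three sets of training labels: classification labels, bounding-box regression targets (via calculate_shift) and per-ROI mask labels (via roi_sampling). The snippet below is a minimal, self-contained sketch of just the class-label encoding on toy tensors; num_stuff, the category values and the stuff-first ordering implied by the "+ 1 - self.num_stuff" offset are assumptions made for illustration, not values taken from the original project.

# Minimal sketch of the class-label encoding (toy tensors, assumed values).
import torch

num_stuff = 3                            # hypothetical number of "stuff" classes
cat = torch.tensor([3, 4, 5])            # ground-truth categories ("thing" ids >= num_stuff)
match = torch.tensor([1, -1, 0, 2, -1])  # per-proposal GT index, -1 = unmatched

pos = match >= 0                                         # positive proposals
cls_lbl = torch.zeros(match.numel(), dtype=torch.long)   # 0 = background / negative
cls_lbl[pos] = cat[match[pos]] + 1 - num_stuff           # "thing" classes re-indexed to start at 1

print(cls_lbl)  # tensor([2, 0, 1, 3, 0])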
Example #2
    def _rois(self, x, proposals, proposals_idx, img_size):
        # Per-axis scale factor from input-image coordinates to feature-map coordinates
        stride = proposals.new(
            [fs / os for fs, os in zip(x.shape[-2:], img_size)])
        # Offset by half a pixel so pixel centers stay aligned after scaling
        proposals = (proposals - 0.5) * stride.repeat(2) + 0.5
        # Extract fixed-size ROI features from the rescaled boxes
        return roi_sampling(x, proposals, proposals_idx, self.roi_size)
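Example #2 maps proposal boxes from input-image coordinates onto the coordinate frame of the feature map x before extracting fixed-size ROI features with roi_sampling. The stride is the per-axis ratio of feature-map size to image size, and the (p - 0.5) * stride + 0.5 form keeps pixel centers aligned under the scaling. Below is a minimal sketch of the coordinate transform alone, with made-up sizes; roi_sampling is a project-specific op and is not reproduced.

# Minimal sketch of the image-to-feature-map coordinate transform (toy sizes).
import torch

img_size = (512, 1024)   # network input size (H, W), hypothetical
feat_size = (64, 128)    # feature-map size (H, W), hypothetical
proposals = torch.tensor([[32.0, 64.0, 256.0, 512.0]])  # one box, (y0, x0, y1, x1)

stride = proposals.new([fs / os for fs, os in zip(feat_size, img_size)])  # per-axis scale
feat_proposals = (proposals - 0.5) * stride.repeat(2) + 0.5               # align pixel centers

print(feat_proposals)  # tensor([[ 4.4375,  8.4375, 32.4375, 64.4375]])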
Example #3
from os import path, mkdir

import numpy as np
import torch
from PIL import Image, ImageDraw

# ensure_dir, invert_roi_bbx and roi_sampling are helpers from the surrounding
# project; they are assumed to be importable from its utility modules.


def save_prediction_image(raw_pred, img_info, out_dir, colors, num_stuff,
                          threshold, obj_cls):
    bbx_pred, cls_pred, obj_pred, msk_pred = raw_pred
    img = Image.open(img_info["abs_path"])
    draw = ImageDraw.Draw(img)

    # Prepare folders and paths
    folder, img_name = path.split(img_info["rel_path"])
    img_name, _ = path.splitext(img_name)
    out_dir = path.join(out_dir, folder)
    ensure_dir(out_dir)
    out_path = path.join(out_dir, img_name + ".jpg")

    # Rescale bounding boxes
    scale_factor = [
        os / bs
        for os, bs in zip(img_info["original_size"], img_info["batch_size"])
    ]
    bbx_pred[:, [0, 2]] = bbx_pred[:, [0, 2]] * scale_factor[0]
    bbx_pred[:, [1, 3]] = bbx_pred[:, [1, 3]] * scale_factor[1]

    # Expand masks
    bbx_inv = invert_roi_bbx(bbx_pred, list(msk_pred.shape[-2:]),
                             list(img_info["original_size"]))
    bbx_idx = torch.arange(0, msk_pred.size(0), dtype=torch.long)
    msk_pred = roi_sampling(msk_pred.cpu().unsqueeze(1).sigmoid(),
                            bbx_inv.cpu(),
                            bbx_idx,
                            list(img_info["original_size"]),
                            padding="zero")
    msk_pred = msk_pred.squeeze(1) > 0.5

    print(type(msk_pred))

    id = 1
    for bbx_pred_i, cls_pred_i, obj_pred_i, msk_pred_i in zip(
            bbx_pred, cls_pred, obj_pred, msk_pred):
        color = colors[cls_pred_i.item() + num_stuff]

        if obj_pred_i.item() > threshold:
            if not obj_cls:
                # detect ALL
                # Cast to uint8 so PIL can build an "L"-mode stencil from the boolean mask
                msk = Image.fromarray((msk_pred_i.numpy() * 192).astype(np.uint8))
                draw.bitmap((0, 0), msk, tuple(color))

                draw.rectangle((
                    bbx_pred_i[1].item(),
                    bbx_pred_i[0].item(),
                    bbx_pred_i[3].item(),
                    bbx_pred_i[2].item(),
                ),
                               outline=tuple(color),
                               width=3)

            if str(cls_pred_i.item()) == obj_cls:
                # detect specific class
                # Cast to uint8 so PIL can build an "L"-mode stencil from the boolean mask
                msk = Image.fromarray((msk_pred_i.numpy() * 192).astype(np.uint8))
                draw.bitmap((0, 0), msk, tuple(color))

                draw.rectangle((
                    bbx_pred_i[1].item(),
                    bbx_pred_i[0].item(),
                    bbx_pred_i[3].item(),
                    bbx_pred_i[2].item(),
                ),
                               outline=tuple(color),
                               width=3)

                _dict = {
                    'id': id,
                    'bbox': {
                        'x1': bbx_pred_i[0].item(),
                        'y1': bbx_pred_i[1].item(),
                        'x2': bbx_pred_i[2].item(),
                        'y2': bbx_pred_i[3].item(),
                    },
                    'cls_pred': cls_pred_i.item(),
                    'obj_pred': obj_pred_i.item(),
                    'msk_pred': msk_pred_i.data.numpy(),
                }
                if not path.isdir(path.join(out_dir, img_name)):
                    mkdir(path.join(out_dir, img_name))

                np.save(path.join(out_dir, img_name, str(id)), _dict)
                id = id + 1

    img.convert(mode="RGB").save(out_path)