Example #1
    def crop_mask(self, mask, boxes, h, w):
        # assumes: import torch; from torch.nn.functional import interpolate
        # upsample the coarse mask to the target image size
        mask = interpolate(mask,
                           size=(h, w),
                           mode='bilinear',
                           align_corners=False)
        prob = torch.zeros((len(boxes), 1, 28, 28),
                           device=boxes.device,
                           dtype=mask.dtype)
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = box.int()
            # crop the box region (x slices the height axis here) and
            # resample it to a fixed 28x28 grid
            prob[i, 0] = interpolate(mask[:, :, x1:x2, y1:y2],
                                     size=(28, 28),
                                     mode='bilinear',
                                     align_corners=False)[0, 0]
        return prob
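
A minimal sketch of the pattern above (assuming `interpolate` is `torch.nn.functional.interpolate`): upsample the coarse mask to image scale, slice out one box region, and resample it to the fixed 28x28 grid. Shapes here are illustrative.

    import torch
    from torch.nn.functional import interpolate

    mask = torch.rand(1, 1, 14, 14)          # one coarse mask, NCHW
    up = interpolate(mask, size=(64, 64), mode='bilinear', align_corners=False)
    crop = up[:, :, 2:40, 3:50]              # slice one box region
    fixed = interpolate(crop, size=(28, 28), mode='bilinear', align_corners=False)
    print(fixed.shape)                       # torch.Size([1, 1, 28, 28])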
Example #2
    def compute_mask(self, mask, boxes, h, w):
        mask = interpolate(mask,
                           size=(h, w),
                           mode='bilinear',
                           align_corners=False)
        mask = mask.permute(0, 2, 3, 1)  # NCHW -> NHWC
        M = int(math.sqrt(mask.shape[-1]))  # channels encode a flattened M x M mask
        prob = torch.zeros((len(boxes), 1, M, M),
                           device=boxes.device,
                           dtype=mask.dtype)
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = box.int()
            # clamp the box to the map; x indexes height, y indexes width here
            x1 = int(max(0, x1))
            x2 = int(min(x2, h - 1))
            y1 = int(max(0, y1))
            y2 = int(min(y2, w - 1))
            if x1 >= x2 or y1 >= y2:
                continue
            # decode the M*M mask stored at the box center; a commented-out
            # alternative averaged the vectors over all in-box locations
            # via self.compute_location
            prob[i, 0] = mask[:, (x1 + x2) // 2,
                              (y1 + y2) // 2, :].reshape(M, M)
        return prob
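
The center-pixel decoding above in isolation: after permuting to NHWC, the channel vector at the box center is reshaped into an M x M mask. A toy sketch with made-up shapes:

    import torch
    from torch.nn.functional import interpolate

    M = 4
    mask = torch.rand(1, M * M, 8, 8)        # each location stores a flattened M x M mask
    mask = interpolate(mask, size=(32, 32), mode='bilinear', align_corners=False)
    mask = mask.permute(0, 2, 3, 1)          # NCHW -> NHWC
    cx, cy = (5 + 20) // 2, (3 + 30) // 2    # box center
    prob = mask[:, cx, cy, :].reshape(M, M)  # decode the center vector
    print(prob.shape)                        # torch.Size([4, 4])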
Example #3
    def forward(self, locations, box_cls, box_regression, centerness, proposal_embed, proposal_margin, pixel_embed, image_sizes, targets):
        """
        Arguments:
            anchors: list[list[BoxList]]
            box_cls: list[tensor]
            box_regression: list[tensor]
            image_sizes: list[(h, w)]
        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        for i, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)):
            em = proposal_embed[i]
            mar = proposal_margin[i]
            if self.fix_margin:
                mar = torch.ones_like(mar) * self.init_margin
            sampled_boxes.append(
                self.forward_for_single_feature_map(
                    l, o, b, c, em, mar, image_sizes, i
                )
            )
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        boxlists = self.select_over_all_levels(boxlists)

        # resize pixel embedding for higher resolution
        N, dim, m_h, m_w = pixel_embed.shape
        o_h = m_h * self.mask_scale_factor
        o_w = m_w * self.mask_scale_factor
        pixel_embed = interpolate(pixel_embed, size=(o_h, o_w), mode='bilinear', align_corners=False)

        boxlists = self.forward_for_mask(boxlists, pixel_embed)

        return boxlists
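
The `list(zip(*sampled_boxes))` line regroups the per-level lists of per-image results into per-image lists of per-level results before concatenation; a toy illustration:

    per_level = [["L0-im0", "L0-im1"],
                 ["L1-im0", "L1-im1"]]
    per_image = list(zip(*per_level))
    print(per_image)  # [('L0-im0', 'L1-im0'), ('L0-im1', 'L1-im1')]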
Example #4
    def prepare_masks(self, o_h, o_w, r_h, r_w, targets_masks):
        masks = []
        for im_i in range(len(targets_masks)):
            mask_t = targets_masks[im_i]
            if len(mask_t) == 0:
                masks.append(mask_t.new_tensor([]))
                continue
            n, h, w = mask_t.shape
            # zero-pad each instance mask to the stride-aligned size (r_h, r_w)
            mask = mask_t.new_zeros((n, r_h, r_w))
            mask[:, :h, :w] = mask_t
            # resize to the mask-branch output size and re-binarize
            resized_mask = interpolate(
                input=mask.float().unsqueeze(0),
                size=(o_h, o_w),
                mode="bilinear",
                align_corners=False,
            )[0].gt(0)

            masks.append(resized_mask)

        return masks
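
The pad-then-resize step above in isolation, with illustrative sizes (padding to the stride-aligned size before resizing keeps the masks registered with the padded input batch):

    import torch
    from torch.nn.functional import interpolate

    mask_t = torch.zeros(2, 37, 50, dtype=torch.uint8)   # two GT masks, odd image size
    mask_t[0, 5:20, 10:30] = 1
    r_h, r_w = 40, 56                                    # stride-aligned size
    o_h, o_w = 10, 14                                    # mask-branch output size
    mask = mask_t.new_zeros((2, r_h, r_w))
    mask[:, :37, :50] = mask_t                           # zero-pad
    resized = interpolate(mask.float().unsqueeze(0), size=(o_h, o_w),
                          mode='bilinear', align_corners=False)[0].gt(0)
    print(resized.shape, resized.dtype)                  # torch.Size([2, 10, 14]) torch.bool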
Example #5
    def compute_single_instance_mask(self, masks):
        # split into single-instance masks, largest first
        instances = torch.split(masks, 1, dim=0)
        instances = sorted(instances, key=lambda x: x.sum(), reverse=True)
        # paint an instance-ID map: the largest mask gets ID 1, the next ID 2,
        # and later (smaller) instances overwrite earlier ones on overlap
        re = instances[0]
        for i, item in enumerate(instances[1:]):
            re = re * (1 - item) + item * (i + 2)

        size = int(math.sqrt(self.box_mask_pw_channels))
        obj = []
        for item in instances:
            # tight bounds of the instance's foreground pixels
            loc = torch.nonzero(item)
            xmin, xmax = loc[:, 1].min(), loc[:, 1].max() + 1
            ymin, ymax = loc[:, 2].min(), loc[:, 2].max() + 1
            tmp = interpolate(item[:, xmin:xmax,
                                   ymin:ymax].unsqueeze(0).float(),
                              size=(size, size),
                              mode='bilinear',
                              align_corners=False) > 0
            obj.append(tmp.squeeze(0).reshape(1, -1))
        obj.insert(0, torch.zeros_like(obj[0]))

        return re, torch.cat(obj, dim=0)
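
The update `re = re * (1 - item) + item * (i + 2)` paints an instance-ID map: the largest mask keeps ID 1, later (smaller) instances get IDs 2, 3, ... and overwrite earlier ones where they overlap. A one-dimensional check:

    import torch

    a = torch.tensor([[1, 1, 0, 0]])   # larger instance, ID 1
    b = torch.tensor([[0, 1, 1, 0]])   # smaller instance, overlaps a
    re = a
    for i, item in enumerate([b]):
        re = re * (1 - item) + item * (i + 2)
    print(re)  # tensor([[1, 2, 2, 0]])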
Example #6
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
    # pad the mask and scale the box by the same factor to soften border effects
    padded_mask, scale = expand_masks(mask[None], padding=padding)
    mask = padded_mask[0, 0]
    box = expand_boxes(box[None], scale)[0]
    box = box.to(dtype=torch.int32)

    # box width/height in pixels (boxes are inclusive, hence the +1)
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # set shape to [batch, C, H, W] for interpolate
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = mask.to(torch.float32)
    mask = interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
    mask = mask[0][0]

    if thresh >= 0:
        mask = mask > thresh
    else:
        # for visualization and debugging, we also
        # allow it to return an unmodified mask
        mask = (mask * 255).to(torch.uint8)

    im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
    # clip the box to the image bounds
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[
        (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])
    ]
    return im_mask
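
The final assignment clips the box to the image and shifts the clipped range into box-local mask coordinates. A self-contained toy of that index arithmetic, with a box that sticks out past the left edge:

    import torch

    im_h, im_w = 6, 6
    box = torch.tensor([-2, 1, 3, 4], dtype=torch.int32)   # x1 < 0
    h = int(box[3] - box[1] + 1)
    w = int(box[2] - box[0] + 1)
    mask = torch.ones((h, w), dtype=torch.uint8)            # box-sized mask
    im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
    x_0, x_1 = max(box[0], 0), min(box[2] + 1, im_w)        # clip to the image
    y_0, y_1 = max(box[1], 0), min(box[3] + 1, im_h)
    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]):(y_1 - box[1]),
                                     (x_0 - box[0]):(x_1 - box[0])]
    print(im_mask.sum())  # tensor(16): 4 rows x 4 visible columns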
Example #7
    def compute_targets_for_locations(self, locations, feature_size, targets,
                                      object_sizes_of_interest):
        labels = []
        reg_targets = []
        mask_targets = []
        xs, ys = locations[:, 0], locations[:, 1]
        device = targets[0].bbox.device

        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            assert targets_per_im.mode == "xyxy"
            bboxes = targets_per_im.bbox
            labels_per_im = targets_per_im.get_field("labels")
            masks_per_im = targets_per_im.get_field(
                'masks').get_mask_tensor().to(device)
            area = targets_per_im.area()

            if len(masks_per_im.size()) < 3:
                masks_per_im = masks_per_im.unsqueeze(0)
            instance_mask, instances = self.compute_single_instance_mask(
                masks_per_im)

            # build per-level mask targets: resize the instance-ID map to each
            # feature size, then look up each location's instance row
            masks = []
            for size in feature_size:
                with torch.no_grad():
                    resized_masks_per_im = interpolate(
                        instance_mask.unsqueeze(0).float(),
                        size=size,
                        mode='bilinear',
                        align_corners=False)
                masks.append(instances[
                    resized_masks_per_im.squeeze().long()].float().reshape(
                        size[0] * size[1], -1))
            masks = torch.cat(masks, dim=0)

            # per-location distances to the four sides of every GT box (FCOS)
            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            reg_targets_per_im = torch.stack([l, t, r, b], dim=2)

            is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
            # limit the regression range for each location
            is_cared_in_the_level = \
                (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
                (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])

            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF

            # if a location still matches more than one object,
            # choose the one with the minimal area
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(
                dim=1)

            reg_targets_per_im = reg_targets_per_im[range(len(locations)),
                                                    locations_to_gt_inds]
            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = 0

            # zero out mask targets at locations with no matched object
            masks[locations_to_min_area == INF] = 0

            labels.append(labels_per_im)
            reg_targets.append(reg_targets_per_im)
            mask_targets.append(masks)

        return labels, reg_targets, mask_targets
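
The l/t/r/b block above is the standard FCOS target encoding: each location regresses its distances to the four sides of a box, and `is_in_boxes` holds only where all four distances are positive. A toy check:

    import torch

    locations = torch.tensor([[16., 16.], [48., 16.]])  # two (x, y) centers
    bboxes = torch.tensor([[0., 0., 32., 32.]])         # one GT box, xyxy
    xs, ys = locations[:, 0], locations[:, 1]
    l = xs[:, None] - bboxes[:, 0][None]
    t = ys[:, None] - bboxes[:, 1][None]
    r = bboxes[:, 2][None] - xs[:, None]
    b = bboxes[:, 3][None] - ys[:, None]
    reg = torch.stack([l, t, r, b], dim=2)              # (locations, boxes, 4)
    print(reg.min(dim=2)[0] > 0)                        # tensor([[ True], [False]])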
Example #8
    def __call__(self, locations, box_cls, box_regression, centerness,
                 proposal_embed, proposal_margin, pixel_embed, targets):
        """
        Arguments:
            locations (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        num_classes = box_cls[0].size(1)
        # recover the (padded) input size from the coarsest FPN level
        im_h = box_cls[4].shape[2] * self.fpn_strides[4]
        im_w = box_cls[4].shape[3] * self.fpn_strides[4]
        labels_per_level, reg_targets_per_level, labels, reg_targets, matched_idxes = self.prepare_targets(
            locations, targets, im_w, im_h)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels_per_level)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
                -1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(
                0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels_per_level[l].reshape(-1))
            reg_targets_flatten.append(reg_targets_per_level[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()
                                                        ])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

        cls_loss = self.cls_loss_func(
            box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(
                reg_targets_flatten)

            # average sum_centerness_targets from all gpus,
            # which is used to normalize centerness-weighed reg loss
            sum_centerness_targets_avg_per_gpu = \
                reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten, reg_targets_flatten,
                centerness_targets) / sum_centerness_targets_avg_per_gpu
            centerness_loss = self.centerness_loss_func(
                centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
        else:
            reg_loss = box_regression_flatten.sum()
            # keep the collective call in sync with GPUs that have positives
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        #################################### Mask Related Losses ######################################
        # get positive proposal labels for each gt instance
        pos_proposal_labels_for_targets = self.get_pos_proposal_indexes(
            locations, box_regression, matched_idxes, targets)

        # get positive samples of embeddings & margins for each gt instance
        proposal_embed_for_targets, valids_for_targets = self.get_proposal_element(
            proposal_embed, pos_proposal_labels_for_targets)
        proposal_margin_for_targets, _ = self.get_proposal_element(
            proposal_margin, pos_proposal_labels_for_targets)

        ######## MEANINGLESS_LOSS #######
        # zero-weighted terms keep every mask-branch output in the autograd
        # graph, so distributed training does not see unused parameters
        mask_loss = box_cls[0].new_tensor(0.0)
        for i in range(len(proposal_embed)):
            mask_loss += 0 * proposal_embed[i].sum()
            mask_loss += 0 * proposal_margin[i].sum()
        mask_loss += 0 * pixel_embed.sum()
        ############ Mask Losses ##############
        # get target masks in prefer size
        N, _, m_h, m_w = pixel_embed.shape
        o_h = m_h * self.mask_scale_factor
        o_w = m_w * self.mask_scale_factor
        r_h = int(m_h * self.fpn_strides[0])
        r_w = int(m_w * self.fpn_strides[0])
        stride = self.fpn_strides[0] / self.mask_scale_factor
        targets_masks = [
            target_im.get_field('masks').convert('mask').instances.masks.to(
                device=pixel_embed.device) for target_im in targets
        ]
        masks_t = self.prepare_masks(o_h, o_w, r_h, r_w, targets_masks)
        pixel_embed = interpolate(input=pixel_embed,
                                  size=(o_h, o_w),
                                  mode="bilinear",
                                  align_corners=False)

        if self.loss_mask_alpha > 0:
            for im in range(N):
                valid = valids_for_targets[im]
                if valid.sum() == 0:
                    continue
                proposal_embed_im = proposal_embed_for_targets[im][valid]
                proposal_margin_im = proposal_margin_for_targets[im][valid]
                masks_t_im = masks_t[im][valid]
                boxes_t_im = targets[im].bbox[valid] / stride

                masks_prob = self.compute_mask_prob(proposal_embed_im,
                                                    proposal_margin_im,
                                                    pixel_embed[im])
                if self.box_padding >= 0:
                    masks_prob_crop, crop_mask = crop_by_box(
                        masks_prob, boxes_t_im, self.box_padding)
                    mask_loss_per_target = self.mask_loss_func(masks_prob_crop,
                                                               masks_t_im,
                                                               mask=crop_mask,
                                                               act=True)
                else:
                    mask_loss_per_target = self.mask_loss_func(masks_prob,
                                                               masks_t_im,
                                                               act=True)

                mask_loss += mask_loss_per_target.mean()

            mask_loss = mask_loss / N * self.loss_mask_alpha

        return cls_loss, reg_loss, centerness_loss, mask_loss
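
`compute_centerness_targets` is referenced above but not defined here; in FCOS it down-weights locations far from a box center. A sketch that follows the standard FCOS definition (the repo's version may differ in detail):

    import torch

    def compute_centerness_targets(reg_targets):
        # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                     (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)

    print(compute_centerness_targets(torch.tensor([[16., 16., 16., 16.],
                                                   [4., 16., 28., 16.]])))
    # tensor([1.0000, 0.3780])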