Ejemplo n.º 1
0
    def crop_and_resize(self,
                        bboxes,
                        out_shape,
                        inds,
                        device='cpu',
                        interpolation='bilinear',
                        binarize=True):
        """Crop one mask per box and resize every crop to ``out_shape``.

        See :func:`BaseInstanceMasks.crop_and_resize`.

        Args:
            bboxes (np.ndarray | torch.Tensor): Boxes to crop with, one per
                row. An index column is prepended to build the RoIs, which
                are expected to end up Nx5 (so presumably Nx4 boxes).
            out_shape (tuple[int]): Spatial size of every resized crop. It
                is passed to ``roi_align`` as the output size and unpacked
                into the ``BitmapMasks`` constructor (presumably
                ``(height, width)`` -- confirm against ``BitmapMasks``).
            inds (np.ndarray | torch.Tensor): For every box, the index
                (along dim 0 of ``self.masks``) of the mask to crop from.
            device (str | torch.device): Device on which the cropping is
                computed. Arrays and tensors are both moved there.
                Defaults to ``'cpu'``.
            interpolation (str): Not used by this implementation.
            binarize (bool): If True, the resampled values are thresholded
                at 0.5 and boolean masks are returned; otherwise the raw
                resampled values are returned. Defaults to True.

        Returns:
            BitmapMasks: The cropped and resized masks.
        """
        if len(self.masks) == 0:
            empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
            return BitmapMasks(empty_masks, *out_shape)

        # Bring the boxes onto ``device`` whether they arrive as an array or
        # as a tensor living elsewhere; concatenating them with the index
        # column below would otherwise fail with a device mismatch.
        if isinstance(bboxes, np.ndarray):
            bboxes = torch.from_numpy(bboxes)
        bboxes = bboxes.to(device=device)

        num_bbox = bboxes.shape[0]
        if num_bbox == 0:
            # Nothing to crop.
            return BitmapMasks([], *out_shape)

        # Same normalisation for the mask indices, which must share a device
        # with the mask tensor for ``index_select``.
        if isinstance(inds, np.ndarray):
            inds = torch.from_numpy(inds)
        inds = inds.to(device=device)

        # Column 0 of each RoI is the index of the (already selected) mask
        # the box applies to, so box i samples from selected mask i.
        fake_inds = torch.arange(
            num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
        rois = torch.cat([fake_inds, bboxes], dim=1)  # Nx5

        gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
            0, inds).to(dtype=rois.dtype)
        # Positional arguments after ``out_shape`` are presumably
        # spatial_scale, sampling_ratio, pool_mode and aligned -- confirm
        # against the imported ``roi_align``.
        targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
                            1.0, 0, 'avg', True).squeeze(1)
        if binarize:
            resized_masks = (targets >= 0.5).cpu().numpy()
        else:
            resized_masks = targets.cpu().numpy()
        return BitmapMasks(resized_masks, *out_shape)
Ejemplo n.º 2
0
    def forward(self, instance_feats, semantic_feat, semantic_pred, rois,
                roi_labels):
        """Fuse per-RoI instance features with semantic features and masks.

        Args:
            instance_feats (Tensor): Per-RoI instance features. Their last
                two dims are used as the spatial size of every tensor that
                is concatenated with them.
            semantic_feat (Tensor): Semantic feature map; it is transformed
                and then pooled per RoI by ``self.semantic_roi_extractor``.
            semantic_pred (Tensor): Semantic prediction map; it is sampled
                per RoI with ``roi_align``.
            rois (Tensor): RoIs, one per row. Column 0 is treated as the
                batch index (presumably followed by x1, y1, x2, y2 --
                confirm against the caller).
            roi_labels (Tensor): For every RoI, the channel of the instance
                logits to keep.

        Returns:
            tuple[Tensor, Tensor]: ``(instance_preds, fused_feats)`` where
            ``instance_preds`` are the selected single-channel instance
            logits (before any sigmoid) and ``fused_feats`` are the
            upsampled fused features concatenated with the resized instance
            and semantic masks.
        """
        roi_size = instance_feats.shape[-2:]
        concat_tensors = [instance_feats]

        # instance-wise semantic feats
        semantic_feat = self.relu(self.semantic_transform_in(semantic_feat))
        ins_semantic_feats = self.semantic_roi_extractor([
            semantic_feat,
        ], rois)
        ins_semantic_feats = self.relu(
            self.semantic_transform_out(ins_semantic_feats))
        concat_tensors.append(ins_semantic_feats)

        # instance masks: keep, for every RoI, only the logit channel of its
        # label. The row index is built on the logits' device so that no
        # host-to-device copy is needed for the lookup.
        instance_logits = self.instance_logits(instance_feats)
        roi_range = torch.arange(len(rois), device=instance_logits.device)
        instance_preds = instance_logits[roi_range, roi_labels][:, None]
        _instance_preds = instance_preds.sigmoid(
        ) if self.mask_use_sigmoid else instance_preds
        # Resize to the full (H, W) of the instance features; passing only
        # the height would force a square output and break the
        # concatenation below for non-square features.
        instance_masks = F.interpolate(_instance_preds,
                                       roi_size,
                                       mode='bilinear',
                                       align_corners=True)
        concat_tensors.append(instance_masks)

        # instance-wise semantic masks
        # NOTE(review): the batch index of every RoI is forced to 0, so all
        # RoIs sample from the first element of ``semantic_pred`` -- confirm
        # this is intended when the batch holds more than one image.
        fake_rois = rois.clone()
        fake_rois[:, 0] = 0
        _semantic_pred = semantic_pred.sigmoid(
        ) if self.mask_use_sigmoid else semantic_pred
        ins_semantic_masks = roi_align(_semantic_pred, fake_rois,
                                       roi_size,
                                       1.0 / self.semantic_out_stride, 0,
                                       'avg', True)
        ins_semantic_masks = F.interpolate(ins_semantic_masks,
                                           roi_size,
                                           mode='bilinear',
                                           align_corners=True)
        concat_tensors.append(ins_semantic_masks)

        # fuse instance feats & instance masks & semantic feats & semantic masks
        fused_feats = torch.cat(concat_tensors, dim=1)
        for conv in self.fuse_conv:
            fused_feats = self.relu(conv(fused_feats))

        fused_feats = self.relu(self.fuse_transform_out(fused_feats))
        fused_feats = self.relu(self.upsample(fused_feats))

        # concat instance and semantic masks with fused feats again, resized
        # to the (H, W) of the upsampled features
        fused_size = fused_feats.shape[-2:]
        instance_masks = F.interpolate(_instance_preds,
                                       fused_size,
                                       mode='bilinear',
                                       align_corners=True)
        ins_semantic_masks = F.interpolate(ins_semantic_masks,
                                           fused_size,
                                           mode='bilinear',
                                           align_corners=True)
        fused_feats = torch.cat(
            [fused_feats, instance_masks, ins_semantic_masks], dim=1)

        return instance_preds, fused_feats