def crop_and_resize(self,
                    bboxes,
                    out_shape,
                    inds,
                    device='cpu',
                    interpolation='bilinear',
                    binarize=True):
    """See :func:`BaseInstanceMasks.crop_and_resize`.

    Crop the masks selected by ``inds`` with ``bboxes`` and resample every
    crop to ``out_shape`` through ``roi_align``.

    Args:
        bboxes (ndarray | Tensor): One box per row. An index column is
            prepended to build Nx5 rois, so an Nx4 layout is expected.
        out_shape (tuple[int]): Spatial size of every output mask. It is
            unpacked into the ``BitmapMasks`` constructor (presumably
            ``(height, width)`` -- confirm against ``BitmapMasks``).
        inds (ndarray | Tensor): For each box, the index into
            ``self.masks`` (first axis) of the mask to crop from.
        device (str | torch.device): Device on which the cropping runs.
            Inputs are moved there; the result is brought back to CPU.
        interpolation (str): Not used by this implementation.
        binarize (bool): If True, threshold the resampled masks at 0.5;
            otherwise keep the raw ``roi_align`` values.

    Returns:
        BitmapMasks: The cropped and resized masks, empty when there are
        no masks or no boxes.
    """
    if len(self.masks) == 0:
        empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
        return BitmapMasks(empty_masks, *out_shape)

    # Accept numpy arrays as well as tensors.
    if isinstance(bboxes, np.ndarray):
        bboxes = torch.from_numpy(bboxes)
    if isinstance(inds, np.ndarray):
        inds = torch.from_numpy(inds)

    num_bbox = bboxes.shape[0]
    if num_bbox == 0:
        # Nothing to crop: skip building rois altogether.
        return BitmapMasks([], *out_shape)

    # Move the inputs to `device` whatever type they arrived as, so that
    # the concatenation and index_select below never mix devices.
    bboxes = bboxes.to(device=device)
    inds = inds.to(device=device)

    # Each selected mask acts as its own one-image "batch" entry, hence the
    # running index in the first roi column.
    fake_inds = torch.arange(
        num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
    rois = torch.cat([fake_inds, bboxes], dim=1)  # Nx5

    gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
        0, inds).to(dtype=rois.dtype)
    targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape, 1.0, 0,
                        'avg', True).squeeze(1)
    if binarize:
        targets = targets >= 0.5
    resized_masks = targets.cpu().numpy()
    return BitmapMasks(resized_masks, *out_shape)
def forward(self, instance_feats, semantic_feat, semantic_pred, rois,
            roi_labels):
    """Fuse instance features with semantic cues and upsample the result.

    Args:
        instance_feats (Tensor): Per-roi features; the last two dims are
            treated as the spatial size and dim 1 as channels.
        semantic_feat (Tensor): Semantic feature map, transformed and then
            pooled per roi by ``self.semantic_roi_extractor``.
        semantic_pred (Tensor): Semantic prediction map, pooled per roi
            with ``roi_align``.
        rois (Tensor): One roi per row; column 0 is overwritten with zeros
            on a copy before pooling ``semantic_pred``.
        roi_labels (Tensor): Per-roi index used to pick one channel from
            the output of ``self.instance_logits``.

    Returns:
        tuple[Tensor, Tensor]: ``instance_preds``, the selected
        single-channel prediction per roi before any sigmoid, and
        ``fused_feats``, the upsampled fused features concatenated with
        the resized instance and semantic masks.
    """
    feat_size = instance_feats.shape[-2:]
    concat_tensors = [instance_feats]

    # instance-wise semantic feats
    semantic_feat = self.relu(self.semantic_transform_in(semantic_feat))
    ins_semantic_feats = self.semantic_roi_extractor([
        semantic_feat,
    ], rois)
    ins_semantic_feats = self.relu(
        self.semantic_transform_out(ins_semantic_feats))
    concat_tensors.append(ins_semantic_feats)

    # instance masks: keep only the channel of each roi's own label
    instance_preds = self.instance_logits(instance_feats)[
        torch.arange(len(rois)), roi_labels][:, None]
    _instance_preds = instance_preds.sigmoid(
    ) if self.mask_use_sigmoid else instance_preds
    # Resize to the full (H, W) of the features so the channel-wise concat
    # below is also valid for non-square roi features.
    instance_masks = F.interpolate(
        _instance_preds, feat_size, mode='bilinear', align_corners=True)
    concat_tensors.append(instance_masks)

    # instance-wise semantic masks
    # NOTE(review): column 0 is presumably the batch index, so zeroing it
    # makes every roi sample from the first entry of `semantic_pred` --
    # confirm this is intended when the batch size is > 1.
    fake_rois = rois.clone()
    fake_rois[:, 0] = 0
    _semantic_pred = semantic_pred.sigmoid(
    ) if self.mask_use_sigmoid else semantic_pred
    ins_semantic_masks = roi_align(_semantic_pred, fake_rois, feat_size,
                                   1.0 / self.semantic_out_stride, 0, 'avg',
                                   True)
    ins_semantic_masks = F.interpolate(
        ins_semantic_masks, feat_size, mode='bilinear', align_corners=True)
    concat_tensors.append(ins_semantic_masks)

    # fuse instance feats & instance masks & semantic feats & semantic masks
    fused_feats = torch.cat(concat_tensors, dim=1)
    for conv in self.fuse_conv:
        fused_feats = self.relu(conv(fused_feats))

    fused_feats = self.relu(self.fuse_transform_out(fused_feats))
    fused_feats = self.relu(self.upsample(fused_feats))

    # concat instance and semantic masks with fused feats again
    fused_size = fused_feats.shape[-2:]
    instance_masks = F.interpolate(
        _instance_preds, fused_size, mode='bilinear', align_corners=True)
    ins_semantic_masks = F.interpolate(
        ins_semantic_masks, fused_size, mode='bilinear', align_corners=True)
    fused_feats = torch.cat(
        [fused_feats, instance_masks, ins_semantic_masks], dim=1)

    return instance_preds, fused_feats