Exemplo n.º 1
0
    def mask_heads_forward_with_coords(
            self, mask_feats, mask_feat_stride, instances
    ):
        """Run every instance's dynamic mask head and return mask probabilities.

        For each instance, the mask features of the image it came from are
        gathered. Unless ``self.disable_rel_coords`` is set, two channels of
        normalised offsets are prepended to them. These offsets run from the
        instance location to every feature location. All instances are
        stacked along the channel axis and passed to
        ``self.mask_heads_forward``, using the weights/biases decoded from
        ``instances.mask_head_params``.

        Args:
            mask_feats: mask-branch features, unpacked below as (N, C, H, W).
            mask_feat_stride: stride of ``mask_feats`` w.r.t. the input image;
                must be a multiple of ``self.mask_out_stride``.
            instances: per-instance predictions providing ``im_inds`` and
                ``mask_head_params``. ``locations`` and ``fpn_levels`` are
                also read when relative coordinates are enabled.

        Returns:
            Sigmoid mask probabilities. The logits are reshaped to
            (-1, 1, H, W), presumably one map per instance. They are then
            upsampled with ``aligned_bilinear`` by
            ``mask_feat_stride // self.mask_out_stride``.
        """
        # Input-image coordinates of every mask-feature location.
        locations = compute_locations(
            mask_feats.size(2), mask_feats.size(3),
            stride=mask_feat_stride, device=mask_feats.device
        )
        n_inst = len(instances)

        # Each instance records the index of its source image (im_inds). Use
        # it to give every instance its own copy of that image's mask
        # features: (N, C, H, W) -> (n_inst, C, H * W).
        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        _, _, H, W = mask_feats.size()
        inst_mask_feats = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)

        if not self.disable_rel_coords:
            # Offset from the instance's own location to every feature
            # location. The offset is zero at the instance location itself.
            # (n_inst, 1, 2) - (1, H * W, 2) -> (n_inst, H * W, 2)
            instance_locations = instances.locations
            relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
            # (n_inst, H * W, 2) -> (n_inst, 2, H * W)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            # Normalise the offsets by the size of interest of the FPN level
            # the instance was predicted on. Assuming sizes_of_interest grows
            # with level, instances from coarser levels get a larger divisor.
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            # (n_inst, 2 + C, H * W)
            mask_head_inputs = torch.cat([relative_coords, inst_mask_feats], dim=1)
        else:
            mask_head_inputs = inst_mask_feats

        # Fold all instances into the channel axis so one call can apply each
        # instance's own dynamic parameters. This presumably relies on a
        # grouped convolution inside mask_heads_forward.
        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = parse_dynamic_params(
            mask_head_params, self.channels,
            self.weight_nums, self.bias_nums
        )

        mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)
        mask_logits = mask_logits.reshape(-1, 1, H, W)

        # The logits must be upsampled by a whole-number factor.
        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(
            mask_logits, int(mask_feat_stride // self.mask_out_stride)
        )

        return mask_logits.sigmoid()
Exemplo n.º 2
0
    def postprocess(self,
                    results,
                    output_height,
                    output_width,
                    padded_im_h,
                    padded_im_w,
                    mask_threshold=0.5):
        """
        Resize the detector outputs to the requested output resolution.

        Input images are usually resized (and padded) before entering the
        detector, so its outputs live at a different resolution from the one
        the caller wants. This method does three things:

        - rescales and clips the boxes;
        - drops instances whose boxes become empty;
        - if ``pred_global_masks`` is present, turns it into binary masks at
          the output resolution.

        Args:
            results (Instances): the raw outputs from the detector.
                `results.image_size` is the (height, width) of the resized
                input the detector saw. A new Instances is built from its
                fields, so the box object it holds may be modified in-place.
            output_height, output_width: the desired output resolution.
            padded_im_h, padded_im_w: size of the padded network input. They
                are used to infer the integer upsampling factor of
                `pred_global_masks`.
            mask_threshold: masks are binarized as `prob > mask_threshold`.

        Returns:
            Instances: the resized outputs. If `pred_global_masks` is present,
            `pred_masks` is set to a float 0/1 tensor of shape
            (num_instances, output_height, output_width).

        Raises:
            ValueError: if `results` has neither `pred_boxes` nor
                `proposal_boxes`.
        """
        resized_im_h, resized_im_w = results.image_size
        scale_x = output_width / resized_im_w
        scale_y = output_height / resized_im_h
        results = Instances((output_height, output_width),
                            **results.get_fields())

        if results.has("pred_boxes"):
            output_boxes = results.pred_boxes
        elif results.has("proposal_boxes"):
            output_boxes = results.proposal_boxes
        else:
            # Without boxes there is nothing to rescale or filter on; fail
            # with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "Predictions must contain 'pred_boxes' or 'proposal_boxes'.")

        output_boxes.scale(scale_x, scale_y)
        # results.image_size is now the output resolution.
        output_boxes.clip(results.image_size)

        results = results[output_boxes.nonempty()]

        if results.has("pred_global_masks"):
            mask_h, mask_w = results.pred_global_masks.size()[-2:]
            # The masks are predicted at a fixed stride of the padded input;
            # recover that stride as an integer upsampling factor.
            factor_h = padded_im_h // mask_h
            factor_w = padded_im_w // mask_w
            assert factor_h == factor_w
            factor = factor_h
            pred_global_masks = aligned_bilinear(results.pred_global_masks,
                                                 factor)
            # Remove the padding, then resize to the requested output size.
            pred_global_masks = pred_global_masks[:, :, :resized_im_h, :
                                                  resized_im_w]
            pred_global_masks = F.interpolate(pred_global_masks,
                                              size=(output_height,
                                                    output_width),
                                              mode="bilinear",
                                              align_corners=False)
            pred_global_masks = pred_global_masks[:, 0, :, :]
            results.pred_masks = (pred_global_masks > mask_threshold).float()

        return results
Exemplo n.º 3
0
    def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride,
                                       instances):
        """Run every instance's dynamic mask head and return upsampled logits.

        Each instance's source image is selected with ``instances.im_inds``,
        and that image's mask features are used as the instance's input.
        Unless relative coordinates are disabled, two channels of normalised
        offsets are prepended. These offsets run from the instance location
        to every feature location. All instances are stacked along the
        channel axis and passed to ``self.mask_heads_forward``, using the
        weights/biases decoded from ``instances.mask_head_params``.

        Args:
            mask_feats: mask-branch features, unpacked as (N, C, H, W).
            mask_feat_stride: stride of ``mask_feats`` w.r.t. the input image;
                must be a multiple of ``self.mask_out_stride``.
            instances: per-instance predictions. ``locations`` and
                ``fpn_levels`` are only read when relative coordinates are on.

        Returns:
            Mask logits, not passed through a sigmoid here. They are reshaped
            to (-1, 1, H, W) and upsampled with ``aligned_bilinear`` by
            ``mask_feat_stride / self.mask_out_stride``.
        """
        grid_locations = compute_locations(mask_feats.size(2),
                                           mask_feats.size(3),
                                           stride=mask_feat_stride,
                                           device=mask_feats.device)
        num_instances = len(instances)
        image_ids = instances.im_inds
        head_params = instances.mask_head_params
        _, _, feat_h, feat_w = mask_feats.size()

        # One copy of the source image's features per instance: (n, C, H*W).
        per_instance_feats = mask_feats[image_ids].reshape(
            num_instances, self.in_channels, feat_h * feat_w)

        if self.disable_rel_coords:
            head_inputs = per_instance_feats
        else:
            # (n, 1, 2) - (1, H*W, 2) -> (n, H*W, 2) -> (n, 2, H*W)
            offsets = (instances.locations.reshape(-1, 1, 2) -
                       grid_locations.reshape(1, -1, 2))
            offsets = offsets.permute(0, 2, 1).float()
            # Scale by the size of interest of each instance's FPN level.
            level_soi = self.sizes_of_interest.float()[instances.fpn_levels]
            offsets = (offsets / level_soi.reshape(-1, 1, 1)).to(
                dtype=mask_feats.dtype)
            head_inputs = torch.cat([offsets, per_instance_feats], dim=1)

        # Fold the instances into the channel axis: (1, n * channels, H, W).
        head_inputs = head_inputs.reshape(1, -1, feat_h, feat_w)

        weights, biases = parse_dynamic_params(head_params, self.channels,
                                               self.weight_nums,
                                               self.bias_nums)
        logits = self.mask_heads_forward(head_inputs, weights, biases,
                                         num_instances)
        logits = logits.reshape(-1, 1, feat_h, feat_w)

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        upsample_factor = int(mask_feat_stride / self.mask_out_stride)
        return aligned_bilinear(logits, upsample_factor)
Exemplo n.º 4
0
    def compute_mask_prob(self, instances, pixel_embed, mask_feat_stride):
        """Return per-instance soft masks derived from embedding distances.

        For every instance, the pixel embeddings of its source image are
        compared with the instance's proposal embedding. The squared
        differences are weighted by the instance's proposal margin, summed
        over the embedding dimension, and mapped through ``exp(-x)``. The
        result is then upsampled with ``aligned_bilinear`` by
        ``mask_feat_stride / self.mask_out_stride``.

        Args:
            instances: provides ``proposal_embed``, ``proposal_margin`` (one
                ``dim``-sized vector per instance) and ``im_inds``.
            pixel_embed: embedding map; its last three axes are read as
                (dim, H, W), presumably (N, dim, H, W) overall.
            mask_feat_stride: stride of ``pixel_embed`` w.r.t. the input image;
                must be a multiple of ``self.mask_out_stride``.

        Returns:
            Tensor of shape (num_instances, 1, H', W').
        """
        centers = instances.proposal_embed
        margins = instances.proposal_margin
        image_ids = instances.im_inds

        dim, m_h, m_w = pixel_embed.shape[-3:]
        num_objs = centers.shape[0]

        # Channels-last, one embedding map per instance: (n, H, W, dim).
        per_obj_pixels = pixel_embed.permute(0, 2, 3, 1)[image_ids]

        # (n, 1, 1, dim) tensors broadcast over the H and W axes.
        centers = centers.view(num_objs, 1, 1, -1)
        margins = margins.view(num_objs, 1, 1, dim)
        weighted_dist = ((per_obj_pixels - centers) ** 2 * margins).sum(dim=3)
        soft_masks = (-weighted_dist).exp()

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        upsample_factor = int(mask_feat_stride / self.mask_out_stride)
        return aligned_bilinear(soft_masks.unsqueeze(1), upsample_factor)
Exemplo n.º 5
0
 def recover_ins2all_test(self, mask_scores, pred_instances):
     """Upsample ``mask_scores`` by a fixed factor of 2 with ``aligned_bilinear``.

     ``pred_instances`` is accepted for interface compatibility and is unused.
     """
     upsample_factor = 2
     return aligned_bilinear(mask_scores, upsample_factor)
Exemplo n.º 6
0
    def forward(self, features, gt_instances=None):
        """Fuse multi-level features into mask features, plus an optional loss.

        Each feature named in ``self.in_features`` is refined by the matching
        ``self.refine`` module. Every level after the first is upsampled with
        ``aligned_bilinear`` to the first level's resolution and summed into
        it. That resolution must be the same integer multiple of the level's
        own in both axes. The sum is passed through ``self.tower``.

        During training with ``self.sem_loss_on``, an auxiliary semantic
        focal loss is also computed from the first input feature level.

        Args:
            features: mapping from feature name to feature map (assumed
                (N, C, H, W) — TODO confirm).
            gt_instances: per-image ground truth. It is only read when the
                semantic loss is active, and must then provide
                ``gt_bitmasks_full`` and ``gt_classes``.

        Returns:
            (mask_feats, losses): ``losses`` is empty unless the semantic loss
            is active, in which case it contains ``"loss_sem"``.
        """
        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])

                # Upsample this level to the first level's resolution.
                target_h, target_w = x.size()[2:]
                h, w = x_p.size()[2:]
                assert target_h % h == 0
                assert target_w % w == 0
                factor_h, factor_w = target_h // h, target_w // w
                assert factor_h == factor_w
                x_p = aligned_bilinear(x_p, factor_h)
                x = x + x_p

        mask_feats = self.tower(x)

        losses = {}
        # auxiliary thing semantic loss
        if self.training and self.sem_loss_on:
            logits_pred = self.logits(
                self.seg_head(features[self.in_features[0]]))

            semantic_targets = self._compute_semantic_targets(gt_instances)

            # To reduce memory, subsample the full-resolution targets once per
            # out_stride x out_stride cell, at offset out_stride // 2.
            # Result: (N, 1, h', w').
            semantic_targets = semantic_targets[:, None, self.out_stride //
                                                2::self.out_stride,
                                                self.out_stride //
                                                2::self.out_stride]

            # One-hot targets: label k + 1 maps to channel k; 0 (background)
            # matches no channel.
            num_classes = logits_pred.size(1)
            class_range = torch.arange(num_classes,
                                       dtype=logits_pred.dtype,
                                       device=logits_pred.device)[:, None,
                                                                  None]
            class_range = class_range + 1
            one_hot = (semantic_targets == class_range).float()
            num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

            loss_sem = sigmoid_focal_loss_jit(
                logits_pred,
                one_hot,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_pos
            losses['loss_sem'] = loss_sem

        return mask_feats, losses

    def _compute_semantic_targets(self, gt_instances):
        """Build per-pixel semantic labels, stacked over images as (N, H, W).

        Each pixel gets ``class + 1`` of the smallest-area ground-truth mask
        that covers it, or 0 (background) when no mask covers it. Images with
        no ground-truth instances are labelled entirely as background.
        """
        semantic_targets = []
        for per_im_gt in gt_instances:
            bitmasks = per_im_gt.gt_bitmasks_full
            num_inst = bitmasks.size(0)
            h, w = bitmasks.size()[-2:]
            if num_inst == 0:
                # Handled explicitly: reshaping the empty instance axis with
                # -1 and taking min over it would raise.
                per_im_targets = torch.zeros(
                    (h, w),
                    dtype=per_im_gt.gt_classes.dtype,
                    device=per_im_gt.gt_classes.device)
            else:
                areas = bitmasks.sum(dim=-1).sum(dim=-1)
                areas = areas[:, None, None].repeat(1, h, w)
                # A pixel outside an instance must never select it.
                areas[bitmasks == 0] = INF
                areas = areas.permute(1, 2, 0).reshape(h * w, num_inst)
                min_areas, inds = areas.min(dim=1)
                per_im_targets = per_im_gt.gt_classes[inds] + 1
                per_im_targets[min_areas == INF] = 0
                per_im_targets = per_im_targets.reshape(h, w)
            semantic_targets.append(per_im_targets)

        return torch.stack(semantic_targets, dim=0)