def mask_heads_forward_with_coords(
            self, mask_feats, mask_feat_stride, instances
    ):
        # real image coordinates corresponding to the 1/8-resolution P3 feature map
        locations = compute_locations(
            mask_feats.size(2), mask_feats.size(3),
            stride=mask_feat_stride, device=mask_feats.device
        )
        n_inst = len(instances)

        # FCOS records the image id of every sampled location and keeps only
        # the positive ids after classification; here the mask features of the
        # matching image are replicated once per instance,
        # i.e. (2, n, h, w) -> (ins, n, h*w)
        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        N, _, H, W = mask_feats.size()

        if not self.disable_rel_coords:
            # FCOS stored the center of each positive location; build relative
            # coordinates here: the target's own location becomes 0 and every
            # other location becomes its offset from the target
            instance_locations = instances.locations
            # (39, 1, 2) - (1, hxw, 2)
            relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
            # (39, hxw, 2) --> (39, 2, hxw)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            # scale each relative offset by a normalization factor: the larger
            # the instance (i.e. the higher its FPN level), the larger the
            # factor that shrinks its offsets
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            mask_head_inputs = torch.cat([
                relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
            ], dim=1)
            # shapes, e.g.: mask_head_inputs (39, 10, 12880), mask_feats (2, 8, 92, 140), relative_coords (39, 2, 12880)
        else:
            mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)

        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = parse_dynamic_params(
            mask_head_params, self.channels,
            self.weight_nums, self.bias_nums
        )

        # each instance in mask_head_inputs now carries the full-image mask
        # features plus the relative offsets to its predicted center
        # torch.Size([1, 580, 100, 136])
        # [torch.Size([464, 10, 1, 1]), torch.Size([464, 8, 1, 1]), torch.Size([58, 8, 1, 1])]
        # [torch.Size([464]), torch.Size([464]), torch.Size([58])]
        mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)
        mask_logits = mask_logits.reshape(-1, 1, H, W)

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride))

        return mask_logits.sigmoid()
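
The example above relies on two helpers that are not shown on this page: parse_dynamic_params, which slices each instance's flat parameter vector into per-layer convolution weights and biases, and mask_heads_forward, which applies them as grouped 1x1 convolutions so that every instance gets its own filters. Below is a minimal sketch following the CondInst design; the exact signatures and reshapes are assumptions, not this page's helpers.

import torch
import torch.nn.functional as F

def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    # params: (n_inst, sum(weight_nums) + sum(bias_nums)) predicted by the controller
    assert params.dim() == 2
    num_layers = len(weight_nums)
    n_inst = params.size(0)
    splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
    weights, biases = splits[:num_layers], splits[num_layers:]
    for l in range(num_layers):
        if l < num_layers - 1:
            # hidden layers: `channels` output channels per instance
            weights[l] = weights[l].reshape(n_inst * channels, -1, 1, 1)
            biases[l] = biases[l].reshape(n_inst * channels)
        else:
            # last layer: a single mask-logit channel per instance
            weights[l] = weights[l].reshape(n_inst, -1, 1, 1)
            biases[l] = biases[l].reshape(n_inst)
    return weights, biases

def mask_heads_forward(features, weights, biases, n_inst):
    # features: (1, n_inst * in_channels, H, W); groups=n_inst keeps each
    # instance's channels convolved only with its own dynamic filters
    x = features
    for l, (w, b) in enumerate(zip(weights, biases)):
        x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=n_inst)
        if l < len(weights) - 1:
            x = F.relu(x)
    return x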
Example #2
    def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride,
                                       instances, out_size):
        locations = compute_locations(mask_feats.size(2),
                                      mask_feats.size(3),
                                      stride=mask_feat_stride,
                                      device=mask_feats.device)
        n_inst = len(instances)

        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        N, _, H, W = mask_feats.size()

        instance_locations = instances.locations
        new_locations = instance_locations / 8.  # (n, 2) holding (x, y) in 1/8 feature-map coords
        new_locations_x = new_locations[:, 0].clamp(min=0, max=W - 1) // int(
            (W / self.grid_num))
        new_locations_y = new_locations[:, 1].clamp(min=0, max=H - 1) // int(
            (H / self.grid_num))
        new_locations_ind = (new_locations_x +
                             self.grid_num * new_locations_y).to(torch.int64)

        if not self.disable_rel_coords:
            relative_coords = instance_locations.reshape(
                -1, 1, 2) - locations.reshape(1, -1, 2)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            mask_head_inputs = torch.cat([
                relative_coords,
                mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
            ], dim=1)
            # 10 = 2 relative-coordinate channels + self.in_channels (8) feature channels
            mask_head_inputs = mask_head_inputs.reshape(n_inst, 10, H, W)
        else:
            mask_head_inputs = mask_feats[im_inds].reshape(
                n_inst, self.in_channels, H * W)
            mask_head_inputs = mask_head_inputs.reshape(
                n_inst, self.in_channels, H, W)

        weights, biases = parse_dynamic_params(mask_head_params, self.channels,
                                               self.weight_nums,
                                               self.bias_nums)
        mask_logits = self.mask_heads_forward(mask_head_inputs, weights,
                                              biases, n_inst,
                                              new_locations_ind)
        mask_logits = mask_logits.reshape(-1, 1, int(H / self.grid_num),
                                          int(W / self.grid_num))

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = F.interpolate(mask_logits,
                                    size=out_size,
                                    mode='bilinear',
                                    align_corners=True)

        return mask_logits.sigmoid()
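
Example #2 differs from the others in that it buckets every instance center into one cell of a grid_num x grid_num partition of the 1/8 feature map; the flat cell index is what its mask_heads_forward receives as new_locations_ind. A minimal worked sketch of that indexing, with assumed values grid_num = 2 and a 100 x 136 feature map:

import torch

H, W, grid_num = 100, 136, 2
centers = torch.tensor([[320., 240.], [900., 700.]])  # instance centers in image pixels (x, y)

new_locations = centers / 8.                                          # to 1/8 feature-map coords
xs = new_locations[:, 0].clamp(min=0, max=W - 1) // int(W / grid_num)
ys = new_locations[:, 1].clamp(min=0, max=H - 1) // int(H / grid_num)
cell_ind = (xs + grid_num * ys).to(torch.int64)                       # row-major cell index
print(cell_ind)  # tensor([0, 3]): top-left cell and bottom-right cell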
Example #3
    def compute_locations(self, features):
        locations = []
        for level, feature in enumerate(features):
            h, w = feature.size()[-2:]
            locations_per_level = compute_locations(h, w,
                                                    self.fpn_strides[level],
                                                    feature.device)
            locations.append(locations_per_level)
        return locations

    def mask_heads_forward_with_coords(self, mask_feats, mask_feat_stride,
                                       instances):
        locations = compute_locations(mask_feats.size(2),
                                      mask_feats.size(3),
                                      stride=mask_feat_stride,
                                      device=mask_feats.device)
        n_inst = len(instances)

        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        N, _, H, W = mask_feats.size()

        if not self.disable_rel_coords:
            instance_locations = instances.locations
            relative_coords = instance_locations.reshape(
                -1, 1, 2) - locations.reshape(1, -1, 2)
            relative_coords = relative_coords.permute(0, 2, 1).float()
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            mask_head_inputs = torch.cat([
                relative_coords,
                mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
            ], dim=1)
        else:
            mask_head_inputs = mask_feats[im_inds].reshape(
                n_inst, self.in_channels, H * W)

        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = parse_dynamic_params(mask_head_params, self.channels,
                                               self.weight_nums,
                                               self.bias_nums)

        mask_logits = self.mask_heads_forward(mask_head_inputs, weights,
                                              biases, n_inst)

        mask_logits = mask_logits.reshape(-1, 1, H, W)

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(
            mask_logits, int(mask_feat_stride / self.mask_out_stride))

        return mask_logits
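
aligned_bilinear, used by Examples #1 and #3 instead of a plain F.interpolate, upsamples by an integer factor while keeping the output aligned with the half-stride pixel centers that compute_locations assumes. A sketch consistent with the AdelaiDet implementation (reproduced from memory, so treat the details as an assumption):

import torch.nn.functional as F

def aligned_bilinear(tensor, factor):
    # tensor: (N, C, H, W); factor: integer upsampling ratio
    assert tensor.dim() == 4 and factor >= 1
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    # replicate-pad one row/column so interpolation is anchored on pixel centers
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh, ow = factor * h + 1, factor * w + 1
    tensor = F.interpolate(tensor, size=(oh, ow), mode="bilinear", align_corners=True)
    # shift by half the factor, then crop back to (factor * h, factor * w)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :oh - 1, :ow - 1]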
Example #5
    def mask_heads_forward_with_coords(
            self, mask_feats, mask_feat_stride, instances, gt_instances=None
    ):
        n_inst = len(instances)

        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params
        instance_locations = instances.locations
        levels = instances.fpn_levels
        # fpn_levels 0..4 correspond to P3..P7; the head splits instances by
        # level: levels (3, 4) use grid_num 1, levels (1, 2) use grid_num 2,
        # and level 0 uses grid_num 4
        ind1 = (levels > self.split[2]) & (levels <= self.split[3])
        ind2 = (levels > self.split[1]) & (levels <= self.split[2])
        ind4 = (levels > self.split[0]) & (levels <= self.split[1])

        weights, biases = parse_dynamic_params(
            mask_head_params, self.channels,
            self.weight_nums, self.bias_nums, [ind1, ind2, ind4], self.concat
        )

        N, _, H, W = mask_feats.size()
        num_layers = self.num_layers
        mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
        if not self.disable_rel_coords:
            base_locations = compute_locations(
                mask_feats.size(2), mask_feats.size(3),
                stride=mask_feat_stride, device=mask_feats.device
            )
            mask_head_inputs = self.add_locaions_info(base_locations, instances, mask_head_inputs)
            mask_head_inputs = mask_head_inputs.reshape(n_inst, self.in_channels + 2, H, W)
        else:
            mask_head_inputs = mask_head_inputs.reshape(n_inst, self.in_channels, H, W)

        def get_location_weights_bias(level_ind, locations, i_weight, i_bias, grid_num=4, grid_inside_num=3):
            new_locations = locations / 4.  # (n, 2) holding (x, y) at 1/4 resolution
            new_locations[:, 0] = new_locations[:, 0].clamp(min=0, max=2*W - 1)
            new_locations[:, 1] = new_locations[:, 1].clamp(min=0, max=2*H - 1)
            new_locations_x = new_locations[:, 0] // int((2*W / grid_num))
            new_locations_y = new_locations[:, 1] // int((2*H / grid_num))
            new_locations_ind = (new_locations_x + grid_num * new_locations_y).to(torch.int64)
            if not self.concat or not len(locations):
                return level_ind, new_locations_ind, i_weight, i_bias, None
            new_inside_locations = torch.zeros_like(new_locations)
            new_inside_locations[:, 0] = new_locations[:, 0] % int((2*W / grid_num))
            new_inside_locations[:, 1] = new_locations[:, 1] % int((2*H / grid_num))
            new_locations_inside_x = new_inside_locations[:, 0] // round((2*W / (grid_num * grid_inside_num)) + 0.5)
            new_locations_inside_y = new_inside_locations[:, 1] // round((2*H / (grid_num * grid_inside_num)) + 0.5)
            relative_inside_location = (new_locations_inside_x + grid_inside_num * new_locations_inside_y).to(
                torch.int64)
            assert (relative_inside_location < grid_inside_num ** 2).all(), \
                (W, H, new_inside_locations, new_locations_inside_x, new_locations_inside_y, relative_inside_location)
            # expand each instance to the 3x3 grid-cell neighborhood around its
            # center cell (via get_grid_map) and replicate its parameters for
            # every valid neighboring cell
            grid_map = torch.from_numpy(get_grid_map(grid_inside_num)).to(device=mask_feats.device)
            new_locations_ind = new_locations_ind.cpu().numpy()
            maps = grid_map[relative_inside_location].clone()  # n, 9, 2
            n, _, _ = maps.shape
            maps[:, :, 0] = maps[:, :, 0] + new_locations_x.repeat(9).reshape(9, -1).T
            maps[:, :, 1] = maps[:, :, 1] + new_locations_y.repeat(9).reshape(9, -1).T
            maps = maps.reshape(-1, 2)
            ind_valid = (maps[:, 0] >= 0) & (maps[:, 0] < grid_num) & (maps[:, 1] >= 0) & (maps[:, 1] < grid_num)
            maps = maps[ind_valid]
            final_locations_ind = (maps[:, 0] + grid_num * maps[:, 1]).to(dtype=torch.int64, device=mask_feats.device)
            param_ind = ind_valid.reshape(n, 9)
            param_ind = param_ind.T*torch.arange(1, n+1).to(device=mask_feats.device)
            param_ind = param_ind.T.flatten()
            param_ind = param_ind[param_ind > 0] - 1
            param_ind = param_ind.to(dtype=torch.int64, device=mask_feats.device)
            gt_ind = torch.arange(0, n*9)[ind_valid].to(dtype=torch.int64, device=mask_feats.device)
            if not len(param_ind):
                return level_ind, new_locations_ind, i_weight, i_bias, None
            for l in range(num_layers):
                assert len(param_ind), param_ind
                i_weight[l] = i_weight[l][param_ind]
                i_bias[l] = i_bias[l][param_ind]
                if l < num_layers - 1:
                    n, c, _, _, _ = i_weight[l].shape
                    i_weight[l] = i_weight[l].reshape(n * c, -1, 1, 1)
                    i_bias[l] = i_bias[l].reshape(n * c)
            return param_ind, final_locations_ind, i_weight, i_bias, gt_ind

        param_ind_2, new_ind2, new_weights_2, new_biases_2, gt_ind_2 = get_location_weights_bias(
            ind2, instance_locations[ind2], weights[1], biases[1], grid_num=self.grid_num[1]
        )
        param_ind_4, new_ind4, new_weights_4, new_biases_4, gt_ind_4 = get_location_weights_bias(
            ind4, instance_locations[ind4], weights[2], biases[2], grid_num=self.grid_num[2]
        )

        mask_head_inputs1 = mask_head_inputs[ind1]
        if not self.concat:
            mask_head_inputs2 = mask_head_inputs[ind2]
            mask_head_inputs4 = mask_head_inputs[ind4]
            inds_list = [ind1, ind2, ind4]
        else:
            mask_head_inputs2 = mask_head_inputs[param_ind_2]
            mask_head_inputs4 = mask_head_inputs[param_ind_4]
            inds_list = [(ind1, ind1), (ind2, param_ind_2), (ind4, param_ind_4)]

        n_inst1 = mask_head_inputs1.shape[0]
        n_inst2 = mask_head_inputs2.shape[0]
        n_inst4 = mask_head_inputs4.shape[0]

        if gt_instances is not None:
            gt_inds = instances.gt_inds
            gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances])  # at 1/4 resolution
            gt_boxes = torch.cat([per_im.gt_boxes.tensor for per_im in gt_instances])  # at full resolution

            if not self.concat:
                gt_inds1 = gt_inds[ind1]
                crop_gt_bitmasks1, out_size1 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[0])
                gt_bitmasks1 = crop_gt_bitmasks1[gt_inds1].unsqueeze(dim=1).to(dtype=mask_feats.dtype)
                gt_inds2 = gt_inds[ind2]
                crop_gt_bitmasks2, out_size2 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[1])
                gt_inds4 = gt_inds[ind4]
                crop_gt_bitmasks4, out_size4 = self.crop_and_expand(gt_bitmasks, gt_boxes, grid_num=self.grid_num[2])
                gt_bitmasks2 = crop_gt_bitmasks2[gt_inds2]
                gt_bitmasks4 = crop_gt_bitmasks4[gt_inds4]
            else:
                gt_inds1 = gt_inds[ind1]
                crop_gt_bitmasks1, out_size1 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[0])
                gt_bitmasks1 = crop_gt_bitmasks1[gt_inds1].unsqueeze(dim=1).to(dtype=mask_feats.dtype)
                crop_gt_bitmasks2 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[1])
                crop_gt_bitmasks4 = self.crop_and_expand_concate(gt_bitmasks, gt_boxes, grid_num=self.grid_num[2])
                gt_inds2 = gt_inds[ind2]
                gt_inds4 = gt_inds[ind4]
                gt_bitmasks2 = [crop_gt_bitmasks2[ind] for ind in gt_inds2.tolist()]
                gt_bitmasks4 = [crop_gt_bitmasks4[ind] for ind in gt_inds4.tolist()]
                gt_bitmasks2 = torch.cat(gt_bitmasks2)[gt_ind_2] if len(gt_inds2) else None
                gt_bitmasks4 = torch.cat(gt_bitmasks4)[gt_ind_4] if len(gt_inds4) else None
                out_size2 = gt_bitmasks2.shape[1:] if len(gt_inds2) else None
                out_size4 = gt_bitmasks4.shape[1:] if len(gt_inds4) else None
            gt_bitmasks2 = gt_bitmasks2.unsqueeze(dim=1).to(dtype=mask_feats.dtype) if len(gt_inds2) else None
            gt_bitmasks4 = gt_bitmasks4.unsqueeze(dim=1).to(dtype=mask_feats.dtype) if len(gt_inds4) else None
            gt_bitmasks_list = [gt_bitmasks1, gt_bitmasks2, gt_bitmasks4]
        else:
            gt_bitmasks_list = []
            out_size1 = (int(H * 2), int(W * 2))
            out_size2 = (H, W)
            out_size4 = (int(H / 2), int(W / 2))
        out_size = {1: out_size1, 2: out_size2, 4: out_size4}

        mask_logits1 = self.mask_heads_forward(
            mask_head_inputs1, out_size[self.grid_num[0]], weights[0], biases[0], n_inst1, [], self.grid_num[0])
        mask_logits2 = self.mask_heads_forward(
            mask_head_inputs2, out_size[self.grid_num[1]], weights[1], biases[1], n_inst2, new_ind2, self.grid_num[1])
        mask_logits4 = self.mask_heads_forward(
            mask_head_inputs4, out_size[self.grid_num[2]], weights[2], biases[2], n_inst4, new_ind4, self.grid_num[2])

        mask_logits_list = [mask_logits1, mask_logits2, mask_logits4]

        return mask_logits_list, gt_bitmasks_list, inds_list, [None, gt_ind_2, gt_ind_4]
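
Finally, every example starts from compute_locations, which turns an H x W feature map into the (x, y) image coordinates of each cell. The helper itself is never shown on this page; below is a minimal sketch matching the FCOS convention of placing each location at the center of its stride cell (an assumption, since the actual implementation lives elsewhere in the repository):

import torch

def compute_locations(h, w, stride, device):
    # (x, y) image coordinates for every feature-map cell, shifted to the
    # cell center by stride // 2 and flattened row-major to (h * w, 2)
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)
    return locations + stride // 2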