Exemplo n.º 1
0
    def forward_select(self, x):
        """Convolve ``x`` using one sampled dilation per output-channel group.

        The output channels of ``self.weight`` are split into ``self.split``
        groups of ``self.expand_num`` channels each. For every group a hard
        choice over ``self.dilation_choice`` is taken from
        ``self.gumbel_softmax(self.alpha_raw, ...)``. The group is then
        convolved with that dilation (padding equal to the dilation,
        ``groups=1``). Group outputs are concatenated along dim 1.

        Args:
            x (Tensor): input passed to ``F.conv2d``.

        Returns:
            Tensor: channel-wise concatenation of the per-group outputs.
        """
        tau = self.get_tau()
        lr_summary = self.lr_raw.detach().item() if self.auto_lr else self.lr_raw
        add_summary('tau',
                    tau=tau,
                    avg=self.alpha_raw.abs().mean().detach().item(),
                    lr_raw=lr_summary)
        alpha_tensor = self.gumbel_softmax(self.alpha_raw, temperature=tau)

        ys = []
        for block_i in range(self.split):
            start = block_i * self.expand_num
            end = (block_i + 1) * self.expand_num
            bias = self.bias[start:end] if self.bias is not None else None
            # Locate the selected choice with argmax rather than ``== 1.``.
            # The straight-through estimator yields ``(y_hard - y).detach() + y``,
            # whose "one" entry can differ from 1.0 by float rounding. An
            # equality test would then select nothing, leaving ``y``
            # undefined or stale from the previous group.
            choice = int(alpha_tensor[block_i].argmax())
            dilation = self.dilation_choice[choice]
            # Multiplying by the (~1) gate keeps the gradient path to alpha.
            y = F.conv2d(x, self.weight[start:end] * alpha_tensor[block_i, choice],
                         bias, self.stride, dilation, dilation, 1)
            ys.append(y)
        return torch.cat(ys, dim=1)
Exemplo n.º 2
0
    def forward(self, img, img_k):
        """Compute MoCo contrastive logits for a batch of query/key images.

        Input:
            img: a batch of query images.
            img_k: a batch of key images (presumably another augmentation of
                the same images; TODO confirm against the data pipeline).
        Output:
            In training mode: ``dict(moco_loss=loss)``.
            Otherwise: ``(logits, labels)``. ``logits`` is N x (1 + K): the
            positive logit in column 0, then one logit per queue entry, all
            divided by ``self.temperature``. ``labels`` is all zeros (the
            index of the positive).
        """
        img_q = img
        if every_n_local_step(self.train_cfg.get('vis_freq', 100)):
            add_image_summary('query', img_q[0], type='mean0var')
            add_image_summary('key', img_k[0], type='mean0var')

        # compute query features: N x C, L2-normalized over channels
        q = nn.functional.normalize(self.encoder_q(img_q), dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            # shuffle for making use of BN, then undo the shuffle
            im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)
            k = nn.functional.normalize(self.encoder_k(im_k), dim=1)
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # positive logits: N x 1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: N x K. The queue is cloned because
        # _dequeue_and_enqueue below presumably updates it in place before
        # backward runs.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: N x (1 + K), temperature-scaled
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature

        # labels: the positive key is always at index 0. Allocate on the
        # logits' device rather than hard-coding ``.cuda()``.
        labels = torch.zeros(logits.shape[0], dtype=torch.long,
                             device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        if self.training:
            loss = self.criterion(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            add_summary('acc', top1=acc1[0], top5=acc5[0])
            return dict(moco_loss=loss)

        return logits, labels
Exemplo n.º 3
0
    def gumbel_softmax(self,
                       logits,
                       temperature,
                       topk=1,
                       summary_name=('raw', 'params')):
        """Select the top-k entries along the last dim as a scaled one-hot.

        The logits are first scaled by ``self.lr_raw`` (when ``'raw'`` is in
        ``summary_name``) or otherwise by ``self.param_beta_multiply``. Three
        conditions must hold for sampling: training mode, the search has not
        ended, and ``self.fixed_param`` is None. In that case Gumbel noise is
        added to the log-probabilities, and the hard top-k is returned
        through a straight-through estimator: the forward value is the hard
        selection and the gradient comes from the soft sample. Otherwise the
        top-k of the scaled logits is taken deterministically.
        NOTE(review): that deterministic path ranks the scaled logits even
        when ``self.fixed_param`` is set; confirm this is intended.

        Args:
            logits: tensor, (..., choice_num).
            temperature: Gumbel-softmax temperature.
            topk: number of entries to select.
            summary_name: names of the summaries for the scaled logits and for
                their probabilities; also selects the scaling factor.

        Returns:
            tensor, (..., choice_num): ``topk`` entries equal to ``1 / topk``
            and zeros elsewhere (in value).
        """
        if 'raw' in summary_name:
            logits = logits * self.lr_raw
        else:
            logits = logits * self.param_beta_multiply

        if self.fixed_param is not None:
            logits_softmax = self.fixed_param
        else:
            logits_softmax = logits.softmax(-1)
        if self.training:
            for name, values in ((summary_name[0], logits),
                                 (summary_name[1], logits_softmax)):
                # One host transfer per tensor instead of one ``.item()`` per element.
                flat = values.detach().reshape(-1).tolist()
                add_summary(name, **{
                    'dyn_dila_{}_{}'.format(self.count, i): v
                    for i, v in enumerate(flat)
                })

        if self.end_after_end is False:
            end_search = False
        elif self.end_after_end is True:
            end_search = (get_epoch() > self.tau_end_epoch)
        else:
            end_search = (get_epoch() > self.end_after_end)

        sampling = self.training and not end_search and self.fixed_param is None
        if sampling:
            # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform[0, 1).
            uniform = torch.rand_like(logits)
            gumbel_sample = -torch.log(-torch.log(uniform + 1e-20) + 1e-20)
            # log_softmax equals softmax().log() mathematically, but it stays
            # finite (and keeps gradients finite) when a probability underflows.
            y = F.softmax((logits.log_softmax(-1) + gumbel_sample) / temperature,
                          dim=-1)
        else:
            y = logits
        shape = y.size()
        _, inds = y.topk(topk, dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, inds.view(-1, topk), 1)
        y_hard = y_hard.view(*shape)
        if sampling:
            # straight-through: value of y_hard, gradient of y
            return ((y_hard - y).detach() + y) / topk
        return y_hard / topk
Exemplo n.º 4
0
    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        """Sample positive and negative bboxes from assigned candidates.

        Optionally prepends the ground-truth boxes to the candidates
        (``self.add_gt_as_proposals``). It then draws up to
        ``num * pos_fraction`` positives and fills the remainder of ``num``
        with negatives. When ``neg_pos_ub >= 0``, negatives are capped at
        ``neg_pos_ub`` times the number of positives.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (Tensor): Boxes to be sampled from; only the first four
                columns are kept.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.
        """
        bboxes = bboxes[:, :4]
        # 1 marks candidates that are ground-truth boxes, 0 marks proposals.
        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)

        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            num_gts = gt_bboxes.shape[0]
            gt_flags = torch.cat(
                [bboxes.new_ones(num_gts, dtype=torch.uint8), gt_flags])
            add_summary('proposal', num_gts=float(num_gts))

        pos_inds = self.pos_sampler._sample_pos(assign_result,
                                                int(self.num * self.pos_fraction),
                                                bboxes=bboxes,
                                                **kwargs)
        # Sampled indices have occasionally been observed to contain
        # duplicates (possibly a PyTorch issue), so deduplicate them.
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            # Treat "no positives" as one positive when computing the cap.
            neg_upper_bound = int(self.neg_pos_ub * max(1, num_sampled_pos))
            num_expected_neg = min(num_expected_neg, neg_upper_bound)

        neg_inds = self.neg_sampler._sample_neg(assign_result,
                                                num_expected_neg,
                                                bboxes=bboxes,
                                                **kwargs).unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                              assign_result, gt_flags)
Exemplo n.º 5
0
 def label_metrics(cls_score, labels):
     """Log Fast R-CNN classification statistics under 'fast_rcnn'.

     Always logs ``accuracy`` (argmax over dim 1 vs ``labels``). When any
     label is foreground (``> 0``), also logs ``fg_accuracy`` and
     ``false_negative``. The latter is the fraction of foreground samples
     predicted as class 0.

     Args:
         cls_score (Tensor): class scores; argmax is taken over dim 1.
         labels (Tensor): target labels; 0 is treated as background.
     """
     predictions = cls_score.argmax(dim=1)
     hits = (predictions == labels).float()
     summary = {'accuracy': hits.mean().item()}

     fg_mask = labels > 0
     fg_count = fg_mask.sum()
     fg_predictions = cls_score[fg_mask].argmax(dim=1)
     # Foreground samples whose predicted class is 0 (background).
     fg_as_bg = float((fg_predictions == 0).sum())
     if fg_count != 0:
         summary['fg_accuracy'] = hits[fg_mask].mean().item()
         summary['false_negative'] = fg_as_bg / float(fg_count)
     add_summary('fast_rcnn', **summary)
Exemplo n.º 6
0
    def proposal_metrics(overlaps):
        """Add summaries for proposals under 'proposal'.

        Logs the mean best IoU per gt. When there is at least one gt, it
        also logs the fraction of gts whose best IoU is >= 0.3, 0.5 and 0.7.

        Args:
            overlaps (Tensor): n x m IoU matrix, #gt x #bbox.
        """
        # find best roi for each gt, for summary only
        best_iou, _ = torch.max(overlaps, dim=1)  # (gt,)
        summaries = {'mean_best_iou': torch.mean(best_iou).item()}

        num_gts = best_iou.size(0)
        # Ratios are undefined without gts; skip them instead of dividing by zero.
        if num_gts > 0:
            for th in [0.3, 0.5, 0.7]:
                best_over_th = float(torch.sum(best_iou >= th).item()) / float(num_gts)
                summaries['best_over_{}'.format(th)] = best_over_th
        add_summary('proposal', **summaries)
Exemplo n.º 7
0
    def forward(self, feats, rois, roi_scale_factor=None):
        """Pool RoI features, routing each RoI to one feature level.

        With a single feature map the RoIs are pooled directly, and no
        rescaling is applied. Otherwise ``self.map_roi_levels`` picks a level
        per RoI, the (optionally rescaled) RoIs of each level are pooled with
        that level's RoI layer, and the number of RoIs per level is logged
        under 'roi_extractor'.

        Args:
            feats: sequence of feature maps, one per level.
            rois (Tensor): RoIs, one per row.
            roi_scale_factor: optional factor passed to ``self.roi_rescale``.

        Returns:
            Tensor: (num_rois, out_channels, *out_size) pooled features.
        """
        if len(feats) == 1:
            return self.roi_layers[0](feats[0], rois)

        num_levels = len(feats)
        # Levels are assigned from the RoIs before any rescaling.
        target_lvls = self.map_roi_levels(rois, num_levels)
        roi_feats = feats[0].new_zeros(rois.size(0), self.out_channels,
                                       *self.roi_layers[0].out_size)
        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for level in range(num_levels):
            in_level = target_lvls == level
            self.summary_dict['num_roi_in_level_{}'.format(level)] = \
                in_level.sum().float().item()
            if not in_level.any():
                continue
            roi_feats[in_level] = self.roi_layers[level](feats[level],
                                                         rois[in_level, :])
        add_summary('roi_extractor', **self.summary_dict)
        return roi_feats
Exemplo n.º 8
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shapes,
                            obj_sizes_of_interest):
        """Build per-level CenterNet-style targets for a single image.

        A gt is handled by a level when its longer side lies within that
        level's ``[min, max]`` in ``obj_sizes_of_interest``.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4).
            gt_labels: tensor, tensor <=> img, (num_gt,). ``label - 1`` is
                used as the heatmap channel.
            feat_shapes: list(tuple). tuple <=> level, (h, w).
            obj_sizes_of_interest: tensor, (level_num, 2).

        Returns:
            all_heatmap: tensor, tensor <=> img, (num_classes, h*w for all levels).
            all_wh: tensor, tensor <=> img, (max_obj*level_num, 2).
            all_reg_mask: tensor, tensor <=> img, (max_obj*level_num,).
            all_ind: tensor, tensor <=> img, (max_obj*level_num,).
            all_center_location: tensor (max_obj*level_num, 2) when
                ``self.use_giou``, otherwise a list of None (one per level).
        """
        level_size = len(self.fpn_strides)
        all_heatmap, all_wh, all_reg_mask, all_ind, all_center_location = [], [], [], [], []
        # Longer side of each gt, repeated per level: (gt_num, level_num).
        max_wh_target = torch.max(gt_boxes[:, 3] - gt_boxes[:, 1],
                                  gt_boxes[:, 2] -
                                  gt_boxes[:, 0]).unsqueeze(-1).repeat(
                                      1, level_size)
        is_cared_in_the_level = \
            (max_wh_target >= obj_sizes_of_interest[:, 0]) & \
            (max_wh_target <= obj_sizes_of_interest[:, 1])  # (gt_num, level_num)

        cared_gt_num_per_level = {}
        for lvl in range(level_size):
            cared_gt_num_per_level['num_lv{}'.format(lvl)] = \
                is_cared_in_the_level[:, lvl].sum().item()
        add_summary('centernet', **cared_gt_num_per_level)

        for lvl in range(level_size):
            # get target for a single level of a single image.
            output_h, output_w = feat_shapes[lvl]
            heatmap = gt_boxes.new_zeros(
                (self.num_classes, output_h, output_w))
            wh = gt_boxes.new_zeros((self.max_objs, 2))
            reg_mask = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.uint8)
            ind = gt_boxes.new_zeros((self.max_objs, ), dtype=torch.long)

            center_location = None
            if self.use_giou:
                center_location = gt_boxes.new_zeros((self.max_objs, 2))

            cared = is_cared_in_the_level[:, lvl]
            # Boolean-mask indexing returns a copy, so the in-place scaling
            # below does not modify the caller's gt_boxes.
            gt_boxes_in_lvl = gt_boxes[cared]
            # Filter the labels with the same mask so that row k of both
            # tensors describes the same object.
            gt_labels_in_lvl = gt_labels[cared]
            if gt_boxes_in_lvl.size(0) > 0:
                gt_boxes_in_lvl /= self.fpn_strides[lvl]
                gt_boxes_in_lvl[:, [0, 2]] = torch.clamp(
                    gt_boxes_in_lvl[:, [0, 2]], 0, output_w - 1)
                gt_boxes_in_lvl[:, [1, 3]] = torch.clamp(
                    gt_boxes_in_lvl[:, [1, 3]], 0, output_h - 1)
                hs = gt_boxes_in_lvl[:, 3] - gt_boxes_in_lvl[:, 1]
                ws = gt_boxes_in_lvl[:, 2] - gt_boxes_in_lvl[:, 0]

                if every_n_local_step(500):
                    add_histogram_summary('mlct_head_hs_lv{}'.format(lvl),
                                          hs.detach().cpu(),
                                          collect_type='none')
                    add_histogram_summary('mlct_head_ws_lv{}'.format(lvl),
                                          ws.detach().cpu(),
                                          collect_type='none')

                # wh / ind / reg_mask only have max_objs slots per level.
                num_objs = min(gt_boxes_in_lvl.size(0), self.max_objs)
                for k in range(num_objs):
                    cls_id = gt_labels_in_lvl[k] - 1
                    h, w = hs[k], ws[k]
                    if h > 0 and w > 0:
                        radius = gaussian_radius((h.ceil(), w.ceil()))
                        radius = max(0, int(radius.item()))
                        center = gt_boxes.new_tensor([
                            (gt_boxes_in_lvl[k, 0] + gt_boxes_in_lvl[k, 2]) / 2,
                            (gt_boxes_in_lvl[k, 1] + gt_boxes_in_lvl[k, 3]) / 2
                        ])
                        # no peak will fall between pixels
                        ct_int = center.to(torch.int)
                        draw_umich_gaussian(heatmap[cls_id], ct_int, radius)
                        # directly predict the width and height
                        wh[k] = wh.new_tensor([1. * w, 1. * h])
                        # flattened (row-major) index of the center cell
                        ind[k] = ct_int[1] * output_w + ct_int[0]
                        if self.use_giou:
                            center_location[k] = center
                        reg_mask[k] = 1

            all_heatmap.append(heatmap.view(heatmap.shape[0], -1))
            all_wh.append(wh)
            all_reg_mask.append(reg_mask)
            all_ind.append(ind)
            all_center_location.append(center_location)

        all_heatmap, all_reg_mask, all_ind = [
            torch.cat(x, dim=-1) for x in [all_heatmap, all_reg_mask, all_ind]
        ]
        all_wh = torch.cat(all_wh, dim=0)
        if self.use_giou:
            all_center_location = torch.cat(all_center_location, dim=0)

        return all_heatmap, all_wh, all_reg_mask, all_ind, all_center_location
Exemplo n.º 9
0
    def forward(self, feats):
        """Run the searchable multi-level cell.

        ``self.pre_convs[i]`` is applied to ``feats[i]``. For every level in
        ``self.ops``, each edge op maps node ``self.in_stride_for_edge[e]``
        to node ``self.out_stride_for_edge[e]``, weighted by its op weights
        (alphas). The inputs arriving at a node are summed, weighted by the
        edge weights (betas), or with weight 1 once the cell is rebuilt.
        When searching with gumbel sampling, the sampled weights are cached
        in ``self.alpha_buffer`` / ``self.beta_buffer`` and reused. Alpha and
        beta statistics are logged while the cell is not rebuilt.

        Returns:
            Tensor: ``self.last_bn`` applied to the first node of the last level.
        """
        pre_feats = [op(feats[i]) for i, op in enumerate(self.pre_convs)]
        edge_idx, beta_start_idx = 0, 0
        alpha_summary, beta_summary = {}, {}
        alphas_softmax = F.softmax(self.search_alphas * self.multiply, dim=-1)
        for level_i, ops_level_i in enumerate(self.ops):
            post_feats = [[] for _ in range(self.node_num_per_level[level_i])]
            for j, op in enumerate(ops_level_i):
                ops_weight = alphas_softmax[edge_idx]
                for idx, w in enumerate(ops_weight.detach().tolist()):
                    alpha_summary['edge_{}_{}'.format(edge_idx, idx)] = w

                if self.do_search and self.search_op_gumbel_softmax:
                    if j in self.alpha_buffer[level_i]:
                        ops_weight = self.alpha_buffer[level_i][j].to(ops_weight.device)
                    else:
                        ops_weight = self.gumbel_softmax(
                            self.search_alphas[edge_idx] * self.multiply, self.tau, topk=1)
                        # Cache a detached copy (``torch.tensor(t)`` on a tensor warns).
                        self.alpha_buffer[level_i][j] = ops_weight.detach().clone()

                post_feats[self.out_stride_for_edge[edge_idx]].append(
                    op(pre_feats[self.in_stride_for_edge[edge_idx]], ops_weight,
                       'edge_{}_'.format(edge_idx)))
                edge_idx += 1

            for stride_j, post_feats_n in enumerate(post_feats):
                beta_end_idx = beta_start_idx + len(post_feats_n)
                if not self.is_rebuilded:
                    # search_betas is 1-D here, so dim=-1 is the implicit
                    # dim=0 previously chosen (with a deprecation warning).
                    edges_weight = F.softmax(
                        self.search_betas[beta_start_idx:beta_end_idx], dim=-1)
                    for idx, w in enumerate(edges_weight.detach().tolist()):
                        beta_summary['edge_{}'.format(beta_start_idx + idx)] = w

                    if self.do_search and self.search_edge_gumbel_softmax:
                        if stride_j in self.beta_buffer[level_i]:
                            edges_weight = self.beta_buffer[level_i][stride_j].to(
                                edges_weight.device)
                        else:
                            if self.has_beta:
                                edges_weight = self.gumbel_softmax(
                                    self.search_betas[beta_start_idx:beta_end_idx],
                                    self.tau, topk=2)
                            else:
                                edges_weight = self.gumbel_softmax(
                                    self.search_alphas.max(-1)[0][beta_start_idx:beta_end_idx]
                                    * self.multiply, self.tau, topk=2)
                            self.beta_buffer[level_i][stride_j] = \
                                edges_weight.detach().clone()
                else:
                    edges_weight = [feats[0].new_ones((1,)) for _ in post_feats_n]
                post_feats[stride_j] = sum(
                    post_feat * w for post_feat, w in zip(post_feats_n, edges_weight))
                beta_start_idx = beta_end_idx

            pre_feats = post_feats

        if not self.is_rebuilded:
            add_summary('alphas', **alpha_summary)
            add_summary('betas', **beta_summary)
            raw_summary = {'tau': self.tau,
                           'alpha': torch.mean(torch.abs(self.search_alphas))}
            if self.has_beta:
                raw_summary['beta'] = torch.mean(torch.abs(self.search_betas))
            add_summary('raw', **raw_summary)

        return self.last_bn(post_feats[0])
Exemplo n.º 10
0
    def ttf_target_single(self, gt_boxes, gt_labels, feat_shape):
        """Build TTF targets for one image on two branches (b1 and b2).

        Gts are sorted by area, largest first, and routed to b1 and/or b2 by
        area or side-length thresholds (``b1_min_length``, ``b2_max_length``),
        according to the ``level_*`` flags.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4).
            gt_labels: tensor, tensor <=> img, (num_gt,).
            feat_shape: tuple, (h_b1, w_b1, h_b2, w_b2).

        Returns:
            tuple of 6 tensors: (heatmap_b1, heatmap_b2, box_target_b1,
            box_target_b2, reg_weight_b1, reg_weight_b2) as produced by
            ``self.ttf_target_single_single``. Before that call they are
            (num_fg, h, w) zeros, (4, h, w) filled with -1 and (1, h, w) zeros.
        """
        h_b1, w_b1, h_b2, w_b2 = feat_shape
        num_channels = self.num_fg

        def _empty_targets(h, w):
            # heatmap, fake heatmap, box target (all -1), regression weight
            return (gt_boxes.new_zeros((num_channels, h, w)),
                    gt_boxes.new_zeros((h, w)),
                    gt_boxes.new_ones((4, h, w)) * -1,
                    gt_boxes.new_zeros((1, h, w)))

        heatmap_b1, fake_heatmap_b1, box_target_b1, reg_weight_b1 = _empty_targets(h_b1, w_b1)
        heatmap_b2, fake_heatmap_b2, box_target_b2, reg_weight_b2 = _empty_targets(h_b2, w_b2)

        # Sort gts by (log) area, largest first.
        areas_log = self.bbox_areas(gt_boxes).log()
        sorted_areas_log, order = torch.topk(areas_log, areas_log.size(0))
        gt_boxes = gt_boxes[order]
        gt_labels = gt_labels[order]

        if self.level_base_area:
            in_b1 = sorted_areas_log >= math.log(self.b1_min_length**2)
            in_b2 = sorted_areas_log <= math.log(self.b2_max_length**2)
        else:
            # (w, h) of every gt
            gt_wh = gt_boxes[..., [2, 3]] - gt_boxes[..., [0, 1]]
            long_side = gt_wh.max(-1)[0]
            short_side = gt_wh.min(-1)[0]
            if self.level_cover:
                in_b1 = long_side >= self.b1_min_length
                in_b2 = short_side <= self.b2_max_length
            elif self.level_mix:
                in_b1 = sorted_areas_log >= math.log(self.b1_min_length**2)
                in_b2 = short_side <= self.b2_max_length
            elif self.level_long:
                in_b1 = long_side >= self.b1_min_length
                in_b2 = long_side <= self.b2_max_length
            else:
                in_b1 = short_side >= self.b1_min_length
                in_b2 = long_side <= self.b2_max_length

        add_summary('gt_num', b1=in_b1.sum().item(), b2=in_b2.sum().item())

        heatmap_b1, box_target_b1, reg_weight_b1 = self.ttf_target_single_single(
            heatmap_b1, box_target_b1, reg_weight_b1, fake_heatmap_b1,
            sorted_areas_log[in_b1], gt_boxes[in_b1], gt_labels[in_b1],
            order[in_b1], [h_b1, w_b1], self.down_ratio_b1, idx=0)

        heatmap_b2, box_target_b2, reg_weight_b2 = self.ttf_target_single_single(
            heatmap_b2, box_target_b2, reg_weight_b2, fake_heatmap_b2,
            sorted_areas_log[in_b2], gt_boxes[in_b2], gt_labels[in_b2],
            order[in_b2], [h_b2, w_b2], self.down_ratio_b2, idx=1)

        return (heatmap_b1, heatmap_b2, box_target_b1, box_target_b2,
                reg_weight_b1, reg_weight_b2)
Exemplo n.º 11
0
    def forward(self, all_locations, gt_boxes, gt_labels, img_metas):
        """Compute FCOS classification/regression targets, grouped by level.

        Args:
            all_locations: list(tensor). tensor <=> level, (hi * wi, 2). locations are fixed and
                only depends on the size of feature map.
            gt_boxes: list(tensor). tensor <=> image, (gt_num, 4).
            gt_labels: list(tensor). tensor <=> image, (gt_num,).
            img_metas: list(dict). Not used here.

        Returns:
            labels_level_first: list(tensor). tensor <=> level, (batch * hi * wi,).
            reg_targets_level_first: list(tensor). tensor <=> level, (batch * hi * wi, 4).
        """
        self.reset_summary()

        num_points_per_level = [len(locs) for locs in all_locations]
        # Per-location size range: each level's (2,) range expanded to (hi * wi, 2).
        size_ranges = []
        for level, locs in enumerate(all_locations):
            level_range = locs.new_tensor(self.obj_sizes_of_interest[level])
            size_ranges.append(level_range[None].expand(len(locs), -1))
        locations_all_level = torch.cat(all_locations, dim=0)
        size_ranges = torch.cat(size_ranges, dim=0)  # (h*w for all levels, 2)

        labels, reg_targets = multi_apply(
            self.forward_single_image,
            gt_boxes,
            gt_labels,
            locations=locations_all_level,
            obj_sizes_of_interest=size_ranges,
        )

        add_summary('fcos_head', average_factor=len(gt_boxes), **self.summary)

        with torch.no_grad():
            # image-major lists of per-level chunks
            labels_per_im = [
                torch.split(x, num_points_per_level, dim=0) for x in labels
            ]
            reg_targets_per_im = [
                torch.split(x, num_points_per_level, dim=0) for x in reg_targets
            ]
            # regroup as level-major, concatenating the images of each level
            labels_level_first = [
                torch.cat([per_im[level] for per_im in labels_per_im], dim=0).detach()
                for level in range(len(all_locations))
            ]
            reg_targets_level_first = [
                torch.cat([per_im[level] for per_im in reg_targets_per_im], dim=0).detach()
                for level in range(len(all_locations))
            ]
        return labels_level_first, reg_targets_level_first
Exemplo n.º 12
0
    def ttf_target_single(self, gt_boxes, gt_labels, pad_shape):
        """Build TTF targets for one image on every output level.

        Gts are sorted by area, largest first. Each level's output size is
        ``pad_shape // down_ratio`` (via ``self.get_down_ratio``). A gt goes
        to a level either by its area falling in ``self.length_range[i]``
        squared, or, when ``self.auto_range`` is set, by the best-IoU
        centered anchor shape.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4).
            gt_labels: tensor, tensor <=> img, (num_gt,).
            pad_shape: tuple, padded image (h, w).

        Returns:
            heatmap, box_target, reg_weight: per-level outputs of
            ``self.ttf_target_single_single``. Before that call they are
            (num_fg, h, w) zeros, (4, h, w) filled with -1 and (1, h, w) zeros.
        """
        heatmap_channel = self.num_fg

        # Sort gts by (log) area, largest first.
        boxes_areas_log = self.bbox_areas(gt_boxes).log()
        boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                    boxes_areas_log.size(0))

        gt_boxes = gt_boxes[boxes_ind]
        gt_labels = gt_labels[boxes_ind]

        if self.auto_range:
            # Origin-centered anchors for every (ratio, length) pair, flattened
            # ratio-major. NOTE(review): ``ind % len(self.length_range)``
            # below assumes length_range holds one scalar length per level in
            # this mode; confirm against the config.
            anchors_range = gt_boxes.new_tensor(self.length_range).view(
                1, -1, 1)
            anchors_range = anchors_range.repeat(len(self.anchor_ratios), 1, 4)
            anchors_range[:, :, [0, 1]] = -anchors_range[:, :, [0, 1]] / 2
            anchors_range[:, :, [2, 3]] = anchors_range[:, :, [2, 3]] / 2
            anchor_ratio_scale = torch.sqrt(
                gt_boxes.new_tensor(self.anchor_ratios)).view(-1, 1, 1)
            anchors_range[:, :, [0, 2]] /= anchor_ratio_scale
            anchors_range[:, :, [1, 3]] *= anchor_ratio_scale
            anchors_range = anchors_range.view(-1, 4)

            # Center each gt at the origin so only its shape is compared.
            # ``detach().clone()`` is the recommended equivalent of
            # ``new_tensor(tensor)``, which warns.
            ct_align_boxes = gt_boxes.detach().clone()
            ct_align_boxes[:, [0, 2]] -= (ct_align_boxes[:, [0]] +
                                          ct_align_boxes[:, [2]]) / 2
            ct_align_boxes[:, [1, 3]] -= (ct_align_boxes[:, [1]] +
                                          ct_align_boxes[:, [3]]) / 2
            overlaps = bbox_overlaps(ct_align_boxes, anchors_range)
            _, ind = overlaps.max(1)
            ind = ind % len(self.length_range)

        heatmap, fake_heatmap, box_target, reg_weight = [], [], [], []
        output_hs, output_ws, gt_level_idx = [], [], []
        for i, raw_down_ratio in enumerate(self.down_ratio):
            down_ratio = self.get_down_ratio(raw_down_ratio)
            output_h, output_w = [shape // down_ratio for shape in pad_shape]
            heatmap.append(
                gt_boxes.new_zeros((heatmap_channel, output_h, output_w)))
            fake_heatmap.append(gt_boxes.new_zeros((output_h, output_w)))
            box_target.append(gt_boxes.new_ones((4, output_h, output_w)) * -1)
            reg_weight.append(gt_boxes.new_zeros((1, output_h, output_w)))

            output_hs.append(output_h)
            output_ws.append(output_w)
            if not self.auto_range:
                gt_level_idx.append(
                    (boxes_area_topk_log >= math.log(self.length_range[i][0]**2))
                    & (boxes_area_topk_log <= math.log(self.length_range[i][1]**2)))
            else:
                gt_level_idx.append(ind == i)

        if gt_level_idx:
            # Number of gts per level, keyed b1, b2, ... for any level count.
            add_summary('gt_num', **{
                'b{}'.format(i + 1): level_idx.sum().item()
                for i, level_idx in enumerate(gt_level_idx)
            })

        heatmap, box_target, reg_weight = multi_apply(
            self.ttf_target_single_single,
            heatmap,
            box_target,
            reg_weight,
            fake_heatmap,
            gt_level_idx,
            output_hs,
            output_ws,
            self.down_ratio,
            gt_boxes=gt_boxes,
            gt_labels=gt_labels,
            boxes_ind=boxes_ind,
            boxes_area_topk_log=boxes_area_topk_log)

        return heatmap, box_target, reg_weight
Exemplo n.º 13
0
    def forward_train(self,
                      img,
                      img_meta,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_meta (list[dict]): list of image info dict where each dict has:
                'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box

            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

            proposals : override rpn proposals with custom proposals. Use when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()
        if every_n_local_step(self.train_cfg.get('vis_every_n_iters', 2000)):
            add_image_summary(
                'image/origin',
                tensor2imgs(img, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])[0],
                gt_bboxes[0].cpu(),
                gt_labels[0].cpu())
            add_feature_summary('feature/x', x[-1].detach().cpu().numpy())

        # RPN forward and loss
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
            bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler, context=self)
            num_imgs = img.size(0)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            num_bgs = []
            num_fgs = []
            for i in range(num_imgs):
                assign_result = bbox_assigner.assign(proposal_list[i],
                                                     gt_bboxes[i],
                                                     gt_bboxes_ignore[i],
                                                     gt_labels[i])
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)
                num_fgs.append(sampling_result.pos_inds.shape[0])
                num_bgs.append(sampling_result.neg_inds.shape[0])
            add_summary(prefix="sample_fast_rcnn_targets",
                        num_fgs=np.mean(num_fgs), num_bgs=np.mean(num_bgs))

        # bbox head forward and loss
        if self.with_bbox:
            rois = bbox2roi([res.bboxes for res in sampling_results])
            # TODO: a more flexible way to decide which feature maps to use
            bbox_feats = self.bbox_roi_extractor(
                x[:self.bbox_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                bbox_feats = self.shared_head(bbox_feats)
            cls_score, bbox_pred = self.bbox_head(bbox_feats)

            bbox_targets = self.bbox_head.get_target(sampling_results,
                                                     gt_bboxes, gt_labels,
                                                     self.train_cfg.rcnn)
            loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                            *bbox_targets)
            losses.update(loss_bbox)

        # mask head forward and loss
        if self.with_mask:
            if not self.share_roi_extractor:
                pos_rois = bbox2roi(
                    [res.pos_bboxes for res in sampling_results])
                mask_feats = self.mask_roi_extractor(
                    x[:self.mask_roi_extractor.num_inputs], pos_rois)
                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
            else:
                pos_inds = []
                device = bbox_feats.device
                for res in sampling_results:
                    pos_inds.append(
                        torch.ones(
                            res.pos_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                    pos_inds.append(
                        torch.zeros(
                            res.neg_bboxes.shape[0],
                            device=device,
                            dtype=torch.uint8))
                pos_inds = torch.cat(pos_inds)
                mask_feats = bbox_feats[pos_inds]

            if mask_feats.shape[0] > 0:
                mask_pred = self.mask_head(mask_feats)
                mask_targets = self.mask_head.get_target(
                    sampling_results, gt_masks, self.train_cfg.rcnn)
                pos_labels = torch.cat(
                    [res.pos_gt_labels for res in sampling_results])
                loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                                pos_labels)
                losses.update(loss_mask)

        return losses
Exemplo n.º 14
0
    def target_single_image(self, gt_boxes, gt_labels, feat_shape):
        """Build the dense training targets for a single image.

        Boxes are re-ordered by descending area (``torch.topk`` over all
        areas) and drawn in that order, so pixels written by smaller boxes
        overwrite those written by larger ones.

        Args:
            gt_boxes: tensor, tensor <=> img, (num_gt, 4). Box corners in
                input-image coordinates; they are divided by
                ``self.down_ratio`` to map them onto the feature map.
            gt_labels: tensor, tensor <=> img, (num_gt,). ``gt_labels[k] - 1``
                is used as the heatmap channel, so labels are presumably
                1-based (0 = background) — TODO confirm against the dataset.
            feat_shape: tuple, ``(output_h, output_w)`` of the feature map.

        Returns:
            heatmap: tensor, (self.num_fg, h, w).
            box_target: tensor, (self.wh_planes, h, w), initialised to -1 and
                filled with the (not down-sampled) gt box coordinates; 4 planes
                when ``self.wh_agnostic``, otherwise 4 planes per class.
            centerness: tensor, (1, h, w); only written when ``self.fovea_hm``.
            wh_weight: tensor, (self.wh_planes // 4, h, w).
            hm_weight: tensor, (self.wh_planes // 4, h, w); only written when
                ``self.ct_version`` is falsy.
        """
        output_h, output_w = feat_shape
        heatmap_channel = self.num_fg

        heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
        # Scratch map reused (zeroed) for every box's gaussian.
        fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
        # -1 marks "no box target" at a pixel.
        box_target = gt_boxes.new_ones(
            (self.wh_planes, output_h, output_w)) * -1
        wh_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        hm_weight = gt_boxes.new_zeros(
            (self.wh_planes // 4, output_h, output_w))
        centerness = gt_boxes.new_zeros((1, output_h, output_w))

        # Per-box weighting term derived from the box area.
        if self.wh_area_process == 'log':
            boxes_areas_log = bbox_areas(gt_boxes).log()
        elif self.wh_area_process == 'sqrt':
            boxes_areas_log = bbox_areas(gt_boxes).sqrt()
        else:
            boxes_areas_log = bbox_areas(gt_boxes)
        # topk over all boxes == sort by area, largest first.
        boxes_area_topk_log, boxes_ind = torch.topk(boxes_areas_log,
                                                    boxes_areas_log.size(0))

        if self.wh_area_process == 'norm':
            boxes_area_topk_log[:] = 1.

        gt_boxes = gt_boxes[boxes_ind]
        gt_labels = gt_labels[boxes_ind]

        # Boxes on the feature-map grid, clipped to its extent.
        feat_gt_boxes = gt_boxes / self.down_ratio
        feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]],
                                               min=0,
                                               max=output_w - 1)
        feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]],
                                               min=0,
                                               max=output_h - 1)
        feat_hs, feat_ws = (feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1],
                            feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0])

        # Shrink ratios handed to calc_region for the positive (center) and
        # ignore regions.
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # The center is computed from the gt boxes at the original scale and
        # then truncated to an int, so no peak falls between pixels.
        ct_ints = (torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                                (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2],
                               dim=1) / self.down_ratio).to(torch.int)

        if self.hm_center_ratio is None:
            # Gaussian radius from box size, stretched by the aspect ratio.
            radiuses = torch.clamp(gaussian_radius(
                (feat_hs.ceil(), feat_ws.ceil())),
                                   min=0)
            hw_ratio_sqrt = (feat_hs / feat_ws).sqrt()
            h_radiuses = (radiuses * hw_ratio_sqrt).int()
            w_radiuses = (radiuses / hw_ratio_sqrt).int()
            if self.ct_gaussian:
                radiuses = radiuses.int()
        else:
            h_radiuses = (feat_hs * self.hm_center_ratio).int()
            w_radiuses = (feat_ws * self.hm_center_ratio).int()
            # Separate radii for the wh-target gaussian when it differs from
            # the heatmap gaussian.
            if (self.center_ratio / 2 !=
                    self.hm_center_ratio) and self.wh_heatmap:
                wh_h_radiuses = (feat_hs * self.center_ratio / 2).int()
                wh_w_radiuses = (feat_ws * self.center_ratio / 2).int()

        # calculate positive (center) regions
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(gt_boxes.transpose(
            0, 1),
                                                         r1,
                                                         use_round=False)
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = [
            torch.round(x / self.down_ratio).int()
            for x in [ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s]
        ]
        ctr_x1s, ctr_x2s = [
            torch.clamp(x, max=output_w - 1) for x in [ctr_x1s, ctr_x2s]
        ]
        ctr_y1s, ctr_y2s = [
            torch.clamp(y, max=output_h - 1) for y in [ctr_y1s, ctr_y2s]
        ]
        # Inclusive widths/heights of the center regions, in feature pixels.
        ctr_xs_diff, ctr_ys_diff = ctr_x2s - ctr_x1s + 1, ctr_y2s - ctr_y1s + 1

        if self.fill_small:
            are_fill_small = (ctr_ys_diff <= 4) & (ctr_xs_diff <= 4)

        # Running count of box_target entries that were already > 0 (i.e.
        # written by an earlier box) before being overwritten.
        collide_pixels_summary = 0
        # larger boxes have lower priority than small boxes.
        for k in range(boxes_ind.shape[0]):
            cls_id = gt_labels[k] - 1
            ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ctr_x1s[k], ctr_y1s[k], ctr_x2s[
                k], ctr_y2s[k]
            ctr_x_diff, ctr_y_diff = ctr_xs_diff[k], ctr_ys_diff[k]

            if self.fovea_hm or (self.fill_small and are_fill_small[k]):
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    feat_gt_boxes[k], r2, (output_h, output_w))

                # Small boxes without fovea_hm use the (larger) ignore region
                # as their positive region.
                if not self.fovea_hm:
                    ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ignore_x1, ignore_y1, ignore_x2, ignore_y2

            fake_heatmap = fake_heatmap.zero_()
            if self.ct_gaussian:
                draw_umich_gaussian(fake_heatmap, ct_ints[k],
                                    radiuses[k].item())
            else:
                draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                       h_radiuses[k].item(),
                                       w_radiuses[k].item())

            if self.fovea_hm:
                # ignore_mask_box is necessary to prevent the ignore areas covering the
                # pos areas of larger boxes
                ignore_mask_box = (heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                                           ignore_x1:ignore_x2 + 1] == 0)
                heatmap[cls_id, ignore_y1:ignore_y2 + 1,
                        ignore_x1:ignore_x2 + 1][ignore_mask_box] = -1
                heatmap[cls_id, ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1
                centerness[0] = torch.max(centerness[0], fake_heatmap)
            else:
                heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

            # Select the pixels that receive this box's regression target.
            if self.wh_heatmap:
                # NOTE(review): when hm_center_ratio is None this condition is
                # True, but wh_h_radiuses / wh_w_radiuses are only assigned in
                # the hm_center_ratio branch above, so this would raise
                # UnboundLocalError — confirm wh_heatmap is never combined
                # with hm_center_ratio=None.
                if self.hm_center_ratio != self.center_ratio / 2:
                    fake_heatmap = fake_heatmap.zero_()
                    draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                           wh_h_radiuses[k].item(),
                                           wh_w_radiuses[k].item())
                box_target_inds = fake_heatmap > 0
            else:
                # NOTE(review): recent PyTorch versions deprecate uint8 index
                # masks in favour of torch.bool.
                box_target_inds = torch.zeros_like(fake_heatmap,
                                                   dtype=torch.uint8)
                box_target_inds[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = 1

            if self.wh_agnostic:
                collide_pixels_summary += (box_target[:, box_target_inds] >
                                           0).sum()

                box_target[:, box_target_inds] = gt_boxes[k][:, None]
            else:
                collide_pixels_summary += (box_target[(
                    cls_id * 4):(cls_id + 1) * 4, box_target_inds] > 0).sum()

                box_target[(cls_id * 4):((cls_id + 1) * 4),
                           box_target_inds] = gt_boxes[k][:, None]

            # Gaussian values at the target pixels, used as per-pixel weights.
            local_heatmap = fake_heatmap[box_target_inds]
            ct_div = local_heatmap.sum()
            local_heatmap *= boxes_area_topk_log[k]

            # Class-agnostic targets share one weight plane.
            if self.wh_agnostic:
                cls_id = 0

            if self.avg_wh_weightv2 and ct_div > 0:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            elif self.avg_wh_weightv3 and ct_div > 0 and ctr_y_diff > 6 and ctr_x_diff > 6:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            elif self.avg_wh_weightv4 and ct_div > 0 and ctr_y_diff > 6 and ctr_x_diff > 6:
                wh_weight[cls_id, box_target_inds] = local_heatmap / ct_div
            else:
                # Uniform weight: the box's area term split evenly over its pixels.
                wh_weight[cls_id, box_target_inds] = \
                    boxes_area_topk_log[k] / box_target_inds.sum().float()

            if self.avg_wh_weightv4:
                wh_weight[cls_id, ct_ints[k, 1].item(), ct_ints[k, 0].item()] = \
                    boxes_area_topk_log[k]

            if not self.ct_version:
                # Near-peak pixels share half the weight; the exact center
                # pixel gets the other half.
                # NOTE(review): if only one pixel exceeds 0.9 this divides by
                # zero (hm_target_num - 1 == 0) and yields inf before the
                # center pixel is overwritten — confirm that is acceptable.
                target_loc = fake_heatmap > 0.9
                hm_target_num = target_loc.sum().float()
                hm_weight[cls_id, target_loc] = 1 / (2 * (hm_target_num - 1))
                hm_weight[cls_id, ct_ints[k, 1].item(),
                          ct_ints[k, 0].item()] = 1 / 2.

        add_summary('box_target', collide_pixels=collide_pixels_summary)
        pos_pixels_summary = (box_target > 0).sum()
        add_summary('box_target', pos_pixels=pos_pixels_summary)
        # NOTE(review): NaN when no box_target entry is positive.
        add_summary('box_target',
                    collide_ratio=collide_pixels_summary /
                    pos_pixels_summary.float())

        return heatmap, box_target, centerness, wh_weight, hm_weight
Exemplo n.º 15
0
    def __call__(self, pred_hm, pred_wh, pred_reg_offset, heatmap, wh,
                 reg_mask, ind, reg_offset, center_location):
        """Compute the weighted CenterNet heatmap, size and offset losses.

        Args:
            pred_hm: tensor, (batch, 80, h, w). Heatmap logits. ``sigmoid_``
                is applied in place, so the caller's tensor is overwritten
                with probabilities.
            pred_wh: tensor, (batch, 2, h, w).
            pred_reg_offset: None or tensor, (batch, 2, h, w).
            heatmap: tensor, (batch, 80, h, w).
            wh: tensor, (batch, max_obj, 2).
            reg_mask: tensor, tensor <=> img, (batch, max_obj).
            ind: tensor, (batch, max_obj).
            reg_offset: tensor, (batch, max_obj, 2).
            center_location: tensor, (batch, max_obj, 2). Only useful when
                using GIOU.

        Returns:
            tuple: ``(hm_loss, wh_loss, off_loss)``, already multiplied by
            ``hm_weight``, ``giou_weight``/``wh_weight`` and ``off_weight``
            respectively. ``off_loss`` is a zero tensor when
            ``self.use_reg_offset`` is falsy.
        """
        height, width = pred_hm.shape[2:]
        # In-place sigmoid: pred_hm itself now holds probabilities.
        hm_prob = torch.clamp(pred_hm.sigmoid_(), min=1e-4, max=1 - 1e-4)
        hm_loss = ct_focal_loss(hm_prob, heatmap) * self.hm_weight

        # Per-object size predictions: (batch, 2, h, w) -> (batch, max_obj, 2).
        wh_pred = tranpose_and_gather_feat(pred_wh, ind)
        obj_mask = reg_mask.unsqueeze(2).expand_as(wh_pred).float()
        avg_factor = obj_mask.sum() + 1e-4

        if not self.use_giou:
            size_loss = smooth_l1_loss if self.use_smooth_l1 else weighted_l1
            wh_loss = size_loss(wh_pred, wh, obj_mask,
                                avg_factor=avg_factor) * self.wh_weight
        else:
            half_pred = wh_pred / 2.
            pred_boxes = torch.cat((center_location - half_pred,
                                    center_location + half_pred),
                                   dim=2)
            # Ground-truth boxes clipped to the feature-map extent.
            gt_top_left = torch.clamp(center_location - wh / 2., min=0)
            gt_bottom_right = center_location + wh / 2.
            gt_bottom_right[:, :, 0] = gt_bottom_right[:, :, 0].clamp(
                max=width - 1)
            gt_bottom_right[:, :, 1] = gt_bottom_right[:, :, 1].clamp(
                max=height - 1)
            gt_boxes = torch.cat((gt_top_left, gt_bottom_right), dim=2)
            wh_loss = giou_loss(pred_boxes, gt_boxes,
                                obj_mask[:, :, 0]) * self.giou_weight

        if self.use_reg_offset:
            off_pred = tranpose_and_gather_feat(pred_reg_offset, ind)
            off_loss = weighted_l1(off_pred, reg_offset, obj_mask,
                                   avg_factor=avg_factor) * self.off_weight
            positive_offsets = reg_offset[reg_offset > 0]
            add_summary('centernet',
                        gt_reg_off=positive_offsets.mean().item())
        else:
            off_loss = hm_loss.new_tensor(0.)

        if every_n_local_step(500):
            add_feature_summary('centernet/heatmap',
                                hm_prob.detach().cpu().numpy())
            add_feature_summary('centernet/gt_heatmap',
                                heatmap.detach().cpu().numpy())
            if self.use_reg_offset:
                add_feature_summary('centernet/reg_offset',
                                    pred_reg_offset.detach().cpu().numpy())

        return hm_loss, wh_loss, off_loss
Exemplo n.º 16
0
    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Every bbox ends up with one value in ``assigned_gt_inds``: ``-1``
        (ignored), ``0`` (negative / background) or ``j + 1`` (matched to
        gt ``j``).

        Args:
            overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
                shape(k, n).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
                When omitted, no labels are assigned and the
                foreground/background/ignore counts are logged with
                ``add_summary('proposal', ...)`` instead.

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_gts = overlaps.size(0)
        num_bboxes = overlaps.size(1)

        # Step 1: everything starts out as "ignore".
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # Degenerate input: nothing to match against.
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # Without any gt, every box is background.
                assigned_gt_inds[:] = 0
            assigned_labels = (None if gt_labels is None else
                               overlaps.new_zeros((num_bboxes, ),
                                                  dtype=torch.long))
            return AssignResult(num_gts,
                                assigned_gt_inds,
                                max_overlaps,
                                labels=assigned_labels)

        # Column-wise: best gt (and its IoU) for every box.
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)
        # Row-wise: best box (and its IoU) for every gt.
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        # Step 2: negatives are boxes whose best IoU lies in [low, high).
        neg_range = None
        if isinstance(self.neg_iou_thr, float):
            neg_range = (0, self.neg_iou_thr)
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            neg_range = self.neg_iou_thr
        if neg_range is not None:
            neg_low, neg_high = neg_range
            assigned_gt_inds[(max_overlaps >= neg_low)
                             & (max_overlaps < neg_high)] = 0

        # Step 3: positives are boxes whose best IoU reaches pos_iou_thr.
        is_pos = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[is_pos] = argmax_overlaps[is_pos] + 1

        # Step 4: make sure each gt keeps its best-matching box(es); later
        # gts win over earlier ones on conflicts.
        for gt_idx, (best_iou, best_box) in enumerate(
                zip(gt_max_overlaps, gt_argmax_overlaps)):
            if best_iou >= self.min_pos_iou:
                if self.gt_max_assign_all:
                    assigned_gt_inds[overlaps[gt_idx, :] == best_iou] = \
                        gt_idx + 1
                else:
                    assigned_gt_inds[best_box] = gt_idx + 1

        if gt_labels is None:
            add_summary(
                'proposal',
                num_fgs=torch.sum(assigned_gt_inds > 0).float().item(),
                num_bgs=torch.sum(assigned_gt_inds == 0).float().item(),
                num_igs=torch.sum(assigned_gt_inds < 0).float().item())
            assigned_labels = None
        else:
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            fg_idx = torch.nonzero(assigned_gt_inds > 0).squeeze()
            if fg_idx.numel() > 0:
                assigned_labels[fg_idx] = gt_labels[
                    assigned_gt_inds[fg_idx] - 1]

        return AssignResult(num_gts,
                            assigned_gt_inds,
                            max_overlaps,
                            labels=assigned_labels)
Exemplo n.º 17
0
    def forward_split(self, x):
        """Convolve ``x`` with a dilation-1 branch and a dilation-2 branch.

        The output channels (rows of ``self.weight``) are split into a "d1"
        group, convolved with padding/dilation 1, and a "d2" group, convolved
        with padding/dilation 2. The results are concatenated along dim 1 as
        ``[d2, d1]``.

        Two modes are supported:

        * fixed split (``self.alpha_out`` is not None): channels
          ``[:int(out_channels * alpha_out)]`` form the d2 group and the rest
          form d1. No gating is applied.
        * learned split (``self.alpha_out`` is None): a Gumbel-softmax sample
          of ``self.alpha_raw`` (cumulatively summed and repeated
          ``expand_num`` times per group) selects the channels. Each branch
          output is scaled by its gate. When ``self.beta_s_raw`` is set,
          input-channel masks are also sampled for both branches.

        Args:
            x (Tensor): input feature map, passed to ``F.conv2d``.

        Returns:
            Tensor: ``cat([d2, d1], dim=1)``, or just the d1 output when the
            d2 branch is empty.
        """
        learned = self.alpha_out is None
        b_d1, b_d2 = None, None
        x1, x2 = x, x
        # Gates and beta masks only exist in the learned-split mode; the
        # fixed-split mode previously referenced them anyway and crashed.
        alpha_tensor = d1_idx = d2_idx = None
        use_beta = False
        if not learned:
            split = int(self.out_channels * self.alpha_out)
            w_d1, w_d2 = self.weight[split:], self.weight[:split]
            if self.bias is not None:
                b_d1, b_d2 = self.bias[split:], self.bias[:split]
        else:
            tau = self.get_tau()
            lr_summary = self.lr_raw.detach().item(
            ) if self.auto_lr else self.lr_raw
            add_summary('tau',
                        tau=tau,
                        avg=self.alpha_raw.abs().mean().detach().item(),
                        lr_raw=lr_summary)
            # Per-group gate from the cumulative Gumbel-softmax sample,
            # repeated expand_num times so there is one entry per channel.
            alpha_tensor = self.gumbel_softmax(self.alpha_raw,
                                               temperature=tau).cumsum(-1)
            alpha_tensor = alpha_tensor[:, None].expand(
                alpha_tensor.size(0), self.expand_num).reshape(-1)
            d1_idx = alpha_tensor.bool()
            d2_idx = (1 - alpha_tensor).bool()
            w_d1 = self.weight[d1_idx]
            w_d2 = self.weight[d2_idx]
            if self.beta_s_raw is not None:
                use_beta = True
                # Input-channel mask/scale for the d1 (dilation-1) branch.
                beta_s_tensor = self.gumbel_softmax(
                    self.beta_s_raw,
                    temperature=tau,
                    summary_name=('raw_betas', 'params_betas')).cumsum(-1)
                beta_s_tensor = beta_s_tensor[:, None].expand(
                    beta_s_tensor.size(0), self.expand_beta_num).reshape(-1)
                s_idx = beta_s_tensor.bool()
                w_d1 = w_d1[:, s_idx] * beta_s_tensor[s_idx][None, :, None,
                                                             None]
                x1 = x1[:, s_idx]

                # Input-channel mask for the d2 (dilation-2) branch; applied
                # below, only if that branch is actually run.
                beta_l_tensor = self.gumbel_softmax(
                    self.beta_l_raw,
                    temperature=tau,
                    summary_name=('raw_betal', 'params_betal')).cumsum(-1)
                beta_l_tensor = beta_l_tensor[:, None].expand(
                    beta_l_tensor.size(0), self.expand_beta_num).reshape(-1)
                l_idx = beta_l_tensor.bool()
                # NOTE(review): the scales are taken with the reversed mask;
                # presumably intentional — confirm.
                l_idx_rev = l_idx.cpu().numpy()[::-1].copy()
            if self.bias is not None:
                # NOTE(review): the bias is scaled by the gate here and the
                # conv output is scaled by the gate again below — confirm the
                # double scaling is intended.
                b_d1 = self.bias[d1_idx] * alpha_tensor[d1_idx]
                b_d2 = self.bias[d2_idx] * (1 - alpha_tensor)[d2_idx]

        y = F.conv2d(x1, w_d1, b_d1, self.stride, 1, 1, 1)

        if learned:
            has_d2 = (1 - alpha_tensor).sum().detach().item() > 0
        else:
            has_d2 = self.alpha_out > 0

        y2 = None
        if has_d2:
            if use_beta:
                w_d2 = w_d2[:, l_idx] * beta_l_tensor[l_idx_rev][None, :,
                                                                 None, None]
                x2 = x2[:, l_idx]
            y2 = F.conv2d(x2, w_d2, b_d2, self.stride, 2, 2, 1)

        if self.use_bn:
            y, y2 = self.share_bn((y, y2))

        if learned:
            # Scale each branch by its gate (d1: alpha, d2: 1 - alpha).
            y = y * alpha_tensor[d1_idx][:, None, None]
            if y2 is not None:
                y2 = y2 * (1 - alpha_tensor)[d2_idx][:, None, None]
        if y2 is not None:
            y = torch.cat([y2, y], dim=1)

        return y