Ejemplo n.º 1
0
    def forward(self, image, im_info, gt_boxes=None):
        """Compute training losses or single-image detections.

        Args:
            image: input batch; it is passed through ``self.preprocess_image``
                and then ``self.backbone``.
            im_info: per-image metadata indexed as ``im_info[i, k]``.
                Columns 0/1 are divided by columns 2/3 to obtain the
                height/width scale, columns 2:4 are used for clipping and
                column 4 is cast to int32 and forwarded to
                ``self.get_ground_truth`` (presumably the number of
                ground-truth boxes -- TODO confirm).
            gt_boxes: ground-truth boxes, only read in training mode.

        Returns:
            In training mode, a dict with the keys ``"total_loss"``,
            ``"loss_cls"``, ``"loss_bbox"`` and ``"loss_ctr"``; the key list
            is also stored on ``self.cfg.losses_keys``.
            Otherwise a tuple ``(pred_score, clipped_boxes)`` for the single
            image in the batch.
        """
        image = self.preprocess_image(image)
        backbone_out = self.backbone(image)
        features = [backbone_out[name] for name in self.in_features]

        box_logits, box_offsets, box_ctrness = self.head(features)

        batch_size = image.shape[0]
        num_classes = self.cfg.num_classes

        def _flatten_levels(level_maps, channels):
            # Move axis 1 to the end and collapse the two middle axes so each
            # level becomes (batch, -1, channels).
            return [
                level.transpose(0, 2, 3, 1).reshape(batch_size, -1, channels)
                for level in level_maps
            ]

        box_logits_list = _flatten_levels(box_logits, num_classes)
        box_offsets_list = _flatten_levels(box_offsets, 4)
        box_ctrness_list = _flatten_levels(box_ctrness, 1)

        anchors_list = self.anchor_generator(features)

        all_level_box_logits = F.concat(box_logits_list, axis=1)
        all_level_box_offsets = F.concat(box_offsets_list, axis=1)
        all_level_box_ctrness = F.concat(box_ctrness_list, axis=1)

        if self.training:
            num_gt_per_image = im_info[:, 4].astype(np.int32)
            gt_labels, gt_offsets, gt_ctrness = self.get_ground_truth(
                anchors_list, gt_boxes, num_gt_per_image)

            # Collapse batch and location axes so predictions and targets
            # line up row by row.
            flat_logits = all_level_box_logits.reshape(-1, num_classes)
            flat_offsets = all_level_box_offsets.reshape(-1, 4)
            flat_ctrness = all_level_box_ctrness.flatten()

            gt_labels = gt_labels.flatten()
            gt_offsets = gt_offsets.reshape(-1, 4)
            gt_ctrness = gt_ctrness.flatten()

            # Labels >= 0 take part in the classification loss; labels > 0
            # are foreground and also drive the box / centerness losses.
            valid_mask = gt_labels >= 0
            fg_mask = gt_labels > 0
            fg_ctrness = gt_ctrness[fg_mask]

            # detach() keeps the normalisers out of the backward pass so no
            # cross-rank sync is triggered there.
            num_fg = layers.all_reduce_mean(fg_mask.sum()).detach()
            sum_ctr = layers.all_reduce_mean(fg_ctrness.sum()).detach()
            fg_normalizer = F.maximum(num_fg, 1)

            # One-hot targets: label k (k > 0) sets column k - 1.
            gt_targets = F.zeros_like(flat_logits)
            gt_targets[fg_mask, gt_labels[fg_mask] - 1] = 1

            focal = layers.sigmoid_focal_loss(
                flat_logits[valid_mask],
                gt_targets[valid_mask],
                alpha=self.cfg.focal_loss_alpha,
                gamma=self.cfg.focal_loss_gamma,
            )
            loss_cls = focal.sum() / fg_normalizer

            iou = layers.iou_loss(
                flat_offsets[fg_mask],
                gt_offsets[fg_mask],
                box_mode="ltrb",
                loss_type=self.cfg.iou_loss_type,
            )
            # Centerness-weighted IoU loss, normalised by the summed weights.
            weighted_iou = (iou * fg_ctrness).sum()
            loss_bbox = (weighted_iou / F.maximum(sum_ctr, 1e-5)
                         * self.cfg.loss_bbox_weight)

            ctr_bce = layers.binary_cross_entropy(
                flat_ctrness[fg_mask],
                fg_ctrness,
            )
            loss_ctr = ctr_bce.sum() / fg_normalizer

            loss_dict = {
                "total_loss": loss_cls + loss_bbox + loss_ctr,
                "loss_cls": loss_cls,
                "loss_bbox": loss_bbox,
                "loss_ctr": loss_ctr,
            }
            self.cfg.losses_keys = list(loss_dict)
            return loss_dict

        # Inference path: multi-image batches are not supported yet.
        assert batch_size == 1

        all_level_anchors = F.concat(anchors_list, axis=0)
        decoded = self.point_coder.decode(all_level_anchors,
                                          all_level_box_offsets[0])
        decoded = decoded.reshape(-1, 4)

        # Undo the resize: divide x coordinates by the width scale and
        # y coordinates by the height scale.
        scale_w = im_info[0, 1] / im_info[0, 3]
        scale_h = im_info[0, 0] / im_info[0, 2]
        box_scale = F.concat([scale_w, scale_h, scale_w, scale_h], axis=0)
        rescaled = decoded / box_scale

        clipped_boxes = layers.get_clipped_boxes(rescaled, im_info[0, 2:4])
        clipped_boxes = clipped_boxes.reshape(-1, 4)

        # Final score: geometric mean of class probability and centerness.
        cls_prob = F.sigmoid(all_level_box_logits)
        ctr_prob = F.sigmoid(all_level_box_ctrness)
        pred_score = F.sqrt(cls_prob * ctr_prob)[0]
        return pred_score, clipped_boxes
Ejemplo n.º 2
0
    def forward(self, image, im_info, gt_boxes=None):
        """Compute training losses or single-image detections.

        Args:
            image: input batch; it is passed through ``self.preprocess_image``
                and then ``self.backbone``.
            im_info: per-image metadata indexed as ``im_info[i, k]``.
                Columns 0/1 are divided by columns 2/3 to obtain the
                height/width scale, columns 2:4 are used for clipping and
                column 4 is cast to int32 and forwarded to
                ``self.get_ground_truth`` (presumably the number of
                ground-truth boxes -- TODO confirm).
            gt_boxes: ground-truth boxes, only read in training mode.

        Returns:
            In training mode, a dict with the keys ``"total_loss"``,
            ``"loss_cls"`` and ``"loss_bbox"``; the key list is also stored
            on ``self.cfg.losses_keys``.
            Otherwise a tuple ``(pred_score, clipped_boxes)`` for the single
            image in the batch.
        """
        image = self.preprocess_image(image)
        backbone_out = self.backbone(image)
        features = [backbone_out[name] for name in self.in_features]

        box_logits, box_offsets = self.head(features)

        batch_size = image.shape[0]
        num_classes = self.cfg.num_classes

        def _flatten_levels(level_maps, channels):
            # Move axis 1 to the end and collapse the two middle axes so each
            # level becomes (batch, -1, channels).
            return [
                level.transpose(0, 2, 3, 1).reshape(batch_size, -1, channels)
                for level in level_maps
            ]

        box_logits_list = _flatten_levels(box_logits, num_classes)
        box_offsets_list = _flatten_levels(box_offsets, 4)

        anchors_list = self.anchor_generator(features)

        all_level_box_logits = F.concat(box_logits_list, axis=1)
        all_level_box_offsets = F.concat(box_offsets_list, axis=1)
        all_level_anchors = F.concat(anchors_list, axis=0)

        if self.training:
            num_gt_per_image = im_info[:, 4].astype(np.int32)
            gt_labels, gt_offsets = self.get_ground_truth(
                all_level_anchors, gt_boxes, num_gt_per_image)

            # Collapse batch and anchor axes so predictions and targets
            # line up row by row.
            flat_logits = all_level_box_logits.reshape(-1, num_classes)
            flat_offsets = all_level_box_offsets.reshape(-1, 4)

            gt_labels = gt_labels.flatten()
            gt_offsets = gt_offsets.reshape(-1, 4)

            # Labels >= 0 take part in the classification loss; labels > 0
            # are foreground and also drive the box regression loss.
            valid_mask = gt_labels >= 0
            fg_mask = gt_labels > 0
            num_fg = fg_mask.sum()
            # Guard against division by zero when there is no foreground.
            fg_normalizer = F.maximum(num_fg, 1)

            # One-hot targets: label k (k > 0) sets column k - 1.
            gt_targets = F.zeros_like(flat_logits)
            gt_targets[fg_mask, gt_labels[fg_mask] - 1] = 1

            focal = layers.sigmoid_focal_loss(
                flat_logits[valid_mask],
                gt_targets[valid_mask],
                alpha=self.cfg.focal_loss_alpha,
                gamma=self.cfg.focal_loss_gamma,
            )
            loss_cls = focal.sum() / fg_normalizer

            regression = layers.smooth_l1_loss(
                flat_offsets[fg_mask],
                gt_offsets[fg_mask],
                beta=self.cfg.smooth_l1_beta,
            )
            loss_bbox = (regression.sum() / fg_normalizer
                         * self.cfg.loss_bbox_weight)

            loss_dict = {
                "total_loss": loss_cls + loss_bbox,
                "loss_cls": loss_cls,
                "loss_bbox": loss_bbox,
            }
            self.cfg.losses_keys = list(loss_dict)
            return loss_dict

        # Inference path: multi-image batches are not supported yet.
        assert batch_size == 1

        decoded = self.box_coder.decode(all_level_anchors,
                                        all_level_box_offsets[0])
        decoded = decoded.reshape(-1, 4)

        # Undo the resize: divide x coordinates by the width scale and
        # y coordinates by the height scale.
        scale_w = im_info[0, 1] / im_info[0, 3]
        scale_h = im_info[0, 0] / im_info[0, 2]
        box_scale = F.concat([scale_w, scale_h, scale_w, scale_h], axis=0)
        rescaled = decoded / box_scale

        clipped_boxes = layers.get_clipped_boxes(rescaled, im_info[0, 2:4])
        clipped_boxes = clipped_boxes.reshape(-1, 4)

        pred_score = F.sigmoid(all_level_box_logits)[0]
        return pred_score, clipped_boxes