Example #1
    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]
            targets: list[BoxList], optional (during training the ground-truth
                boxes are appended to the proposals)

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        num_levels = len(objectness)
        anchors = list(zip(*anchors))
        for a, o, b in zip(anchors, objectness, box_regression):
            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        return boxlists
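Every example on this page leans on a cat_boxlist helper that merges a list of BoxLists into one. As a reference point, here is a minimal sketch inferred from the call sites above (same image size and mode, per-box fields concatenated along dim 0); the actual library implementation may differ in details such as field validation, and the import path is assumed to follow the maskrcnn-benchmark layout.

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList  # path assumed

def cat_boxlist_sketch(bboxes):
    """Concatenate BoxLists sharing the same image size, mode and fields."""
    size, mode = bboxes[0].size, bboxes[0].mode
    assert all(b.size == size and b.mode == mode for b in bboxes)

    cat_boxes = BoxList(torch.cat([b.bbox for b in bboxes], dim=0), size, mode)
    # fields() lists per-box annotations such as "scores" or "labels"
    # (some forks expose this as get_fields())
    for field in bboxes[0].fields():
        data = torch.cat([b.get_field(field) for b in bboxes], dim=0)
        cat_boxes.add_field(field, data)
    return cat_boxes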
Example #2
def cat_boxlist_with_keypoints(boxlists):
    assert all(boxlist.has_field("keypoints") for boxlist in boxlists)

    kp = [boxlist.get_field("keypoints").keypoints for boxlist in boxlists]
    kp = cat(kp, 0)

    fields = boxlists[0].get_fields()
    fields = [field for field in fields if field != "keypoints"]

    boxlists = [boxlist.copy_with_fields(fields) for boxlist in boxlists]
    boxlists = cat_boxlist(boxlists)
    boxlists.add_field("keypoints", kp)
    return boxlists
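A hedged usage sketch of the helper above with toy data: _ToyKeypoints is a hypothetical stand-in for the real keypoint structure (only its .keypoints tensor attribute is relied on), and the import path is assumed to follow the maskrcnn-benchmark layout.

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList  # path assumed

class _ToyKeypoints:
    def __init__(self, keypoints):
        self.keypoints = keypoints  # (N, K, 3) keypoint tensor

a = BoxList(torch.rand(2, 4) * 100, (640, 480), mode="xyxy")
a.add_field("keypoints", _ToyKeypoints(torch.rand(2, 17, 3)))
b = BoxList(torch.rand(3, 4) * 100, (640, 480), mode="xyxy")
b.add_field("keypoints", _ToyKeypoints(torch.rand(3, 17, 3)))

merged = cat_boxlist_with_keypoints([a, b])
print(merged.get_field("keypoints").shape)  # torch.Size([5, 17, 3])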
Example #3
    def select_over_all_levels(self, boxlists):
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            scores = boxlists[i].get_field("scores")
            labels = boxlists[i].get_field("labels")
            boxes = boxlists[i].bbox
            boxlist = boxlists[i]
            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(boxlist_for_class,
                                                self.nms_thresh,
                                                score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels",
                    torch.full((num_labels, ),
                               j,
                               dtype=torch.int64,
                               device=scores.device))
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1)
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
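The torch.kthvalue trick at the end is worth spelling out: to keep the top fpn_post_nms_top_n detections, the code finds the (N - top_n + 1)-th smallest score and keeps everything at or above it. A toy illustration with made-up scores:

import torch

scores = torch.tensor([0.9, 0.3, 0.7, 0.8, 0.2, 0.6])
top_n = 3
# the (N - top_n + 1)-th smallest value equals the top_n-th largest value
thresh, _ = torch.kthvalue(scores, scores.numel() - top_n + 1)
keep = torch.nonzero(scores >= thresh.item()).squeeze(1)
print(keep)  # tensor([0, 2, 3]) -- indices of the three highest scores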
Example #4
    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying non-maximum suppression (NMS).
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4:(j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels, ), j, dtype=torch.int64,
                           device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result
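The slicing above relies on a class-major layout: boxlist.bbox is reshaped to (N, num_classes * 4) and the scores to (N, num_classes), so class j occupies box columns j*4:(j+1)*4 and score column j. A toy sketch of that layout (shapes only, hypothetical values):

import torch

num_boxes, num_classes = 3, 5
boxes = torch.rand(num_boxes, num_classes * 4)            # one box per class
scores = torch.softmax(torch.rand(num_boxes, num_classes), dim=1)

j = 2                                                     # a foreground class
inds = (scores[:, j] > 0.05).nonzero().squeeze(1)         # score threshold
boxes_j = boxes[inds, j * 4:(j + 1) * 4]                  # (M, 4) class-j boxes
scores_j = scores[inds, j]                                # (M,) matching scores
print(boxes_j.shape, scores_j.shape)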
Example #5
    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Arguments:
            anchors (list[list[BoxList]])
            objectness (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness, box_regression = \
                concat_box_prediction_layers(objectness, box_regression)

        objectness = objectness.squeeze()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss
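smooth_l1_loss is not shown on this page; a sketch consistent with the call above (a beta knee and a size_average switch) would look roughly like the following, written here as an assumption rather than the library's exact code.

import torch

def smooth_l1_loss(input, target, beta=1. / 9, size_average=True):
    # quadratic for residuals below beta, linear beyond (Fast R-CNN style)
    n = torch.abs(input - target)
    loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()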
Example #6
    def __call__(self, anchors, box_cls, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            retinanet_cls_loss (Tensor)
            retinanet_regression_loss (Tensor)
        """
        anchors = [
            cat_boxlist(anchors_per_image) for anchors_per_image in anchors
        ]
        labels, regression_targets = self.prepare_targets(anchors, targets)

        N = len(labels)
        box_cls, box_regression = \
                concat_box_prediction_layers(box_cls, box_regression)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)
        pos_inds = torch.nonzero(labels > 0).squeeze(1)

        retinanet_regression_loss = smooth_l1_loss(
            box_regression[pos_inds],
            regression_targets[pos_inds],
            beta=self.bbox_reg_beta,
            size_average=False,
        ) / (max(1,
                 pos_inds.numel() * self.regress_norm))

        labels = labels.int()

        retinanet_cls_loss = self.box_cls_loss_func(
            box_cls, labels) / (pos_inds.numel() + N)

        return retinanet_cls_loss, retinanet_regression_loss
Example #7
    def add_gt_proposals(self, proposals, targets):
        """
        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]
        """
        # Get the device we're operating on
        device = proposals[0].bbox.device

        gt_boxes = [target.copy_with_fields([]) for target in targets]

        # later cat of bbox requires all fields to be present for all bbox
        # so we need to add a dummy for objectness that's missing
        for gt_box in gt_boxes:
            gt_box.add_field("objectness",
                             torch.ones(len(gt_box), device=device))

        proposals = [
            cat_boxlist((proposal, gt_box))
            for proposal, gt_box in zip(proposals, gt_boxes)
        ]

        return proposals
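A hedged usage sketch with toy boxes: ground-truth boxes get a dummy objectness of 1 so their fields match the proposals' and cat_boxlist can merge them. Import paths assume the maskrcnn-benchmark layout.

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist

proposals = BoxList(torch.rand(10, 4) * 100, (640, 480), mode="xyxy")
proposals.add_field("objectness", torch.rand(10))

gt = BoxList(torch.rand(2, 4) * 100, (640, 480), mode="xyxy")
gt.add_field("objectness", torch.ones(2))  # dummy score, mirrors add_gt_proposals

merged = cat_boxlist((proposals, gt))
print(len(merged))  # 12 boxes: 10 proposals followed by 2 ground-truth boxes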
Example #8
    def generate_feats(self, x, proposals, proposals_key=None, ver="local"):
        x = self.head(torch.cat(x, dim=0))
        if self.conv is not None:
            x = F.relu(self.conv(x))

        if proposals_key is not None:
            assert ver == "local"

            x_key = self.pooler((x[0:1, ...], ), proposals_key)
            x_key = x_key.flatten(start_dim=1)

        if proposals:
            x = self.pooler((x, ), proposals)
            x = x.flatten(start_dim=1)

        rois = cat_boxlist(proposals).bbox

        if ver == "local":
            x_key = F.relu(self.l_fcs[0](x_key))
        x = F.relu(self.l_fcs[0](x))

        if self.global_cache:
            if ver == "local":
                rois_key = proposals_key[0].bbox
                x_key = self.update_lm(x_key)
            x = self.update_lm(x)

        # distillation
        if ver in ("local", "memory"):
            x_dis = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(x, self.base_num, dim=0)
            ],
                              dim=0)
            rois_dis = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(rois, self.base_num, dim=0)
            ],
                                 dim=0)

        if ver == "memory":
            self.memory_cache.append({
                "rois_cur": rois_dis,
                "rois_ref": rois,
                "feats_cur": x_dis,
                "feats_ref": x
            })
            for _ in range(self.stage - 1):
                self.memory_cache.append({
                    "rois_cur": rois_dis,
                    "rois_ref": rois_dis
                })
        elif ver == "local":
            self.local_cache.append({
                "rois_cur":
                torch.cat([rois_key, rois_dis], dim=0),
                "rois_ref":
                rois,
                "feats_cur":
                torch.cat([x_key, x_dis], dim=0),
                "feats_ref":
                x
            })
            for _ in range(self.stage - 2):
                self.local_cache.append({
                    "rois_cur":
                    torch.cat([rois_key, rois_dis], dim=0),
                    "rois_ref":
                    rois_dis
                })
            self.local_cache.append({
                "rois_cur": rois_key,
                "rois_ref": rois_dis
            })
        elif ver == "global":
            self.global_cache.append({"feats": x})
Example #9
    def _forward_test(self, x, proposals):
        proposals, proposals_ref, x_refs = proposals

        rois_cur = cat_boxlist(proposals).bbox
        rois_ref = proposals_ref.bbox

        if self.conv is not None:
            x = self.head(x)
            x = (F.relu(self.conv(x)), )
        else:
            x = (self.head(x), )
        x = self.pooler(x, proposals)
        x = x.flatten(start_dim=1)

        position_embedding = self.cal_position_embedding(rois_cur, rois_ref)

        for i in range(self.base_stage):
            x = F.relu(self.fcs[i](x))
            attention = self.attention_module_multi_head(x,
                                                         x_refs,
                                                         position_embedding,
                                                         feat_dim=1024,
                                                         group=16,
                                                         dim=(1024, 1024,
                                                              1024),
                                                         index=i)
            x = x + attention

        if self.advanced_stage > 0:
            x_refs_adv = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(x_refs, self.base_num, dim=0)
            ],
                                   dim=0)
            rois_ref_adv = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(rois_ref, self.base_num, dim=0)
            ],
                                     dim=0)
            position_embedding_adv = torch.cat([
                x[..., :self.advanced_num]
                for x in torch.split(position_embedding, self.base_num, dim=-1)
            ],
                                               dim=-1)

            position_embedding = self.cal_position_embedding(
                rois_ref_adv, rois_ref)

            for i in range(self.advanced_stage):
                attention = self.attention_module_multi_head(
                    x_refs_adv,
                    x_refs,
                    position_embedding,
                    feat_dim=1024,
                    group=16,
                    dim=(1024, 1024, 1024),
                    index=i + self.base_stage)
                x_refs_adv = x_refs_adv + attention
                x_refs_adv = F.relu(self.fcs[i + self.base_stage](x_refs_adv))

            attention = self.attention_module_multi_head(
                x,
                x_refs_adv,
                position_embedding_adv,
                feat_dim=1024,
                group=16,
                dim=(1024, 1024, 1024),
                index=self.base_stage + self.advanced_stage)
            x = x + attention

        return x
Example #10
    def _forward_train(self, x, proposals):
        num_refs = len(x) - 1
        x = self.head(torch.cat(x, dim=0))
        if self.conv is not None:
            x = F.relu(self.conv(x))
        x, x_refs = torch.split(x, [1, num_refs], dim=0)

        proposals, proposals_cur, proposals_refs = \
            proposals[0][0], proposals[1], proposals[2:]

        x, x_cur = torch.split(
            self.pooler(
                (x, ),
                [cat_boxlist([proposals, proposals_cur], ignore_field=True)]),
            [len(proposals), len(proposals_cur)],
            dim=0)
        x, x_cur = x.flatten(start_dim=1), x_cur.flatten(start_dim=1)

        if proposals_refs:
            x_refs = self.pooler((x_refs, ), proposals_refs)
            x_refs = x_refs.flatten(start_dim=1)
            x_refs = torch.cat([x_cur, x_refs], dim=0)
        else:
            x_refs = x_cur

        rois_cur = proposals.bbox
        rois_ref = cat_boxlist([proposals_cur, *proposals_refs]).bbox
        position_embedding = self.cal_position_embedding(rois_cur, rois_ref)

        x_refs = F.relu(self.fcs[0](x_refs))

        for i in range(self.base_stage):
            x = F.relu(self.fcs[i](x))
            attention = self.attention_module_multi_head(x,
                                                         x_refs,
                                                         position_embedding,
                                                         feat_dim=1024,
                                                         group=16,
                                                         dim=(1024, 1024,
                                                              1024),
                                                         index=i)
            x = x + attention

        if self.advanced_stage > 0:
            x_refs_adv = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(x_refs, self.base_num, dim=0)
            ],
                                   dim=0)
            rois_ref_adv = torch.cat([
                x[:self.advanced_num]
                for x in torch.split(rois_ref, self.base_num, dim=0)
            ],
                                     dim=0)
            position_embedding_adv = torch.cat([
                x[..., :self.advanced_num]
                for x in torch.split(position_embedding, self.base_num, dim=-1)
            ],
                                               dim=-1)

            position_embedding = self.cal_position_embedding(
                rois_ref_adv, rois_ref)

            for i in range(self.advanced_stage):
                attention = self.attention_module_multi_head(
                    x_refs_adv,
                    x_refs,
                    position_embedding,
                    feat_dim=1024,
                    group=16,
                    dim=(1024, 1024, 1024),
                    index=i + self.base_stage)
                x_refs_adv = x_refs_adv + attention
                x_refs_adv = F.relu(self.fcs[i + self.base_stage](x_refs_adv))

            attention = self.attention_module_multi_head(
                x,
                x_refs_adv,
                position_embedding_adv,
                feat_dim=1024,
                group=16,
                dim=(1024, 1024, 1024),
                index=self.base_stage + self.advanced_stage)
            x = x + attention

        return x
    def _forward_test(self, imgs, infos, targets=None):
        """
        forward for the test phase.
        :param imgs:
        :param infos:
        :param targets:
        :return:
        """
        def update_feature(img=None,
                           feats=None,
                           proposals=None,
                           proposals_feat=None):
            assert (img is not None) or (feats is not None
                                         and proposals is not None
                                         and proposals_feat is not None)

            if img is not None:
                feats = self.backbone(img)[0]
                # note: `imgs` (not `img`) is passed here; only its shape is
                # used, so this does not cause an error, but it is not explicit.
                proposals = self.rpn(imgs, (feats, ), version="ref")
                proposals_feat = self.roi_heads.box.feature_extractor(
                    feats, proposals, pre_calculate=True)

            self.feats.append(feats)
            self.proposals.append(proposals[0])
            self.proposals_dis.append(proposals[0][:self.advanced_num])
            self.proposals_feat.append(proposals_feat)
            self.proposals_feat_dis.append(proposals_feat[:self.advanced_num])

        if targets is not None:
            raise ValueError("In testing mode, targets should be None")

        if infos["frame_category"] == 0:  # a new video
            self.seg_len = infos["seg_len"]
            self.end_id = 0

            self.feats = deque(maxlen=self.all_frame_interval)
            self.proposals = deque(maxlen=self.all_frame_interval)
            self.proposals_dis = deque(maxlen=self.all_frame_interval)
            self.proposals_feat = deque(maxlen=self.all_frame_interval)
            self.proposals_feat_dis = deque(maxlen=self.all_frame_interval)
            # initialize the queues of the video-frame memory module
            self.roi_heads.box.feature_extractor.init_memory()
            if self.global_enable:
                # initialize the global feature pool
                self.roi_heads.box.feature_extractor.init_global()

            feats_cur = self.backbone(imgs.tensors)[0]
            proposals_cur = self.rpn(imgs, (feats_cur, ), version="ref")
            proposals_feat_cur = self.roi_heads.box.feature_extractor(
                feats_cur, proposals_cur, pre_calculate=True)
            # on the initial frame, fill the leading slots (the first 12 frames) with the current frame
            while len(self.feats) < self.key_frame_location + 1:
                update_feature(None, feats_cur, proposals_cur,
                               proposals_feat_cur)
            # fill in the future frames; for real-time video the loop below is skipped
            while len(self.feats) < self.all_frame_interval:
                # advance the frame counter by one, but never past the end of the video
                self.end_id = min(self.end_id + 1, self.seg_len - 1)
                end_filename = infos["pattern"] % self.end_id
                end_image = Image.open(infos["img_dir"] %
                                       end_filename).convert("RGB")

                end_image = infos["transforms"](end_image)
                if isinstance(end_image, tuple):
                    end_image = end_image[0]
                end_image = end_image.view(1, *end_image.shape).to(self.device)

                update_feature(end_image)

        elif infos["frame_category"] == 1:
            # for non-initial frames, the local frame window advances by one frame per step
            self.end_id = min(self.end_id + 1, self.seg_len - 1)

            end_image = infos["ref_l"][0].tensors

            update_feature(end_image)

        # 1. update global
        if infos["ref_g"]:
            # update the global frames: 10 frames are generated on the initial frame, then one per step
            for global_img in infos["ref_g"]:
                feats = self.backbone(global_img.tensors)[0]
                proposals = self.rpn(global_img, (feats, ), version="ref")
                proposals_feat = self.roi_heads.box.feature_extractor(
                    feats, proposals, pre_calculate=True)

                self.roi_heads.box.feature_extractor.update_global(
                    proposals_feat)
        # the current (key) frame of the video
        feats = self.feats[self.key_frame_location]
        proposals, proposal_losses = self.rpn(imgs, (feats, ), None)

        proposals_ref = cat_boxlist(list(self.proposals))
        proposals_ref_dis = cat_boxlist(list(self.proposals_dis))
        proposals_feat_ref = torch.cat(list(self.proposals_feat), dim=0)
        proposals_feat_ref_dis = torch.cat(list(self.proposals_feat_dis),
                                           dim=0)

        proposals_list = [
            proposals, proposals_ref, proposals_ref_dis, proposals_feat_ref,
            proposals_feat_ref_dis
        ]

        if self.roi_heads:
            x, result, detector_losses = self.roi_heads(
                feats, proposals_list, None)
        else:
            result = proposals

        return result
    def _forward_test(self, imgs, infos, targets=None):
        """
        forward for the test phase.
        :param imgs:
        :param infos: per-frame metadata dict; infos["frame_category"] is 0
            for the first frame of a video and 1 for subsequent frames
        :param targets:
        :return:
        """
        def update_feature(img=None,
                           feats=None,
                           proposals=None,
                           proposals_feat=None):
            assert (img is not None) or (feats is not None
                                         and proposals is not None
                                         and proposals_feat is not None)

            if img is not None:
                feats = self.backbone(img)[0]
                # note: `imgs` (not `img`) is passed here; only its shape is
                # used, so this does not cause an error, but it is not explicit.
                proposals = self.rpn(imgs, (feats, ), version="ref")
                proposals_feat = self.roi_heads.box.feature_extractor(
                    feats, proposals, pre_calculate=True)

            self.feats.append(feats)
            self.proposals.append(proposals[0])
            self.proposals_feat.append(proposals_feat)
            # if self.advanced:
            #     self.proposals_adv.append(proposals[0][:self.advanced_num])
            #     self.proposals_feat_dis.append(proposals_feat[:self.advanced_num])

        if targets is not None:
            raise ValueError("In testing mode, targets should be None")

        if infos["frame_category"] == 0:  # a new video
            self.seg_len = infos["seg_len"]
            self.end_id = 0

            self.feats = deque(maxlen=self.all_frame_interval)
            self.proposals = deque(maxlen=self.all_frame_interval)
            self.proposals_feat = deque(maxlen=self.all_frame_interval)

            # if self.advanced:
            #     self.proposals_adv = deque(maxlen=self.all_frame_interval)
            #     self.proposals_feat_adv = deque(maxlen=self.all_frame_interval)

            feats_cur = self.backbone(imgs.tensors)[0]
            proposals_cur = self.rpn(imgs, (feats_cur, ), version="ref")
            proposals_feat_cur = self.roi_heads.box.feature_extractor(
                feats_cur, proposals_cur, pre_calculate=True)
            while len(self.feats) < self.key_frame_location + 1:
                update_feature(None, feats_cur, proposals_cur,
                               proposals_feat_cur)

            while len(self.feats) < self.all_frame_interval:
                self.end_id = min(self.end_id + 1, self.seg_len - 1)
                end_filename = infos["pattern"] % self.end_id
                end_image = Image.open(infos["img_dir"] %
                                       end_filename).convert("RGB")

                end_image = infos["transforms"](end_image)
                if isinstance(end_image, tuple):
                    end_image = end_image[0]
                end_image = end_image.view(1, *end_image.shape).to(self.device)

                update_feature(end_image)

        elif infos["frame_category"] == 1:
            self.end_id = min(self.end_id + 1, self.seg_len - 1)
            end_image = infos["ref"][0].tensors

            update_feature(end_image)

        feats = self.feats[self.key_frame_location]
        proposals, proposal_losses = self.rpn(imgs, (feats, ), None)

        proposals_ref = cat_boxlist(list(self.proposals))
        proposals_feat_ref = torch.cat(list(self.proposals_feat), dim=0)

        # if self.advanced:
        #     proposals_ref_adv = cat_boxlist(list(self.proposals_adv))
        #     proposals_feat_ref_adv = torch.cat(list(self.proposals_feat_adv), dim=0)
        #
        #     proposals_list = [proposals, proposals_ref, proposals_feat_ref, proposals_ref_adv, proposals_feat_ref_adv]
        # else:
        proposals_list = [proposals, proposals_ref, proposals_feat_ref]

        if self.roi_heads:
            x, result, detector_losses = self.roi_heads(
                feats, proposals_list, None)
        else:
            result = proposals

        return result
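The sliding local window in both _forward_test variants is just a bounded collections.deque; once it holds all_frame_interval entries, appending a new frame silently evicts the oldest. A toy illustration:

from collections import deque

buf = deque(maxlen=4)
for frame_id in range(6):
    buf.append(frame_id)
print(list(buf))  # [2, 3, 4, 5] -- the oldest frames were dropped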