def forward(self, anchors, objectness, box_regression, targets=None):
    """
    Arguments:
        anchors: list[list[BoxList]]
        objectness: list[tensor]
        box_regression: list[tensor]

    Returns:
        boxlists (list[BoxList]): the post-processed anchors, after
            applying box decoding and NMS
    """
    sampled_boxes = []
    num_levels = len(objectness)
    anchors = list(zip(*anchors))
    for a, o, b in zip(anchors, objectness, box_regression):
        sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

    boxlists = list(zip(*sampled_boxes))
    boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

    if num_levels > 1:
        boxlists = self.select_over_all_levels(boxlists)

    # append ground-truth bboxes to proposals
    if self.training and targets is not None:
        boxlists = self.add_gt_proposals(boxlists, targets)

    return boxlists
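# The double `zip(*...)` above is the standard transpose idiom: `anchors`
# arrives as a per-image list of per-level BoxLists and is flipped into
# per-level tuples of per-image BoxLists, so each FPN level is processed in
# one call; `sampled_boxes` is flipped back the same way. A minimal,
# self-contained illustration (the strings stand in for BoxLists):
per_image = [["img0_lvl0", "img0_lvl1"], ["img1_lvl0", "img1_lvl1"]]
per_level = list(zip(*per_image))
# per_level == [("img0_lvl0", "img1_lvl0"), ("img0_lvl1", "img1_lvl1")]
assert list(zip(*per_level)) == [tuple(row) for row in per_image]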
def cat_boxlist_with_keypoints(boxlists):
    assert all(boxlist.has_field("keypoints") for boxlist in boxlists)

    # concatenate the keypoints separately and re-attach them to the merged
    # BoxList, since copy_with_fields below deliberately drops that field
    kp = [boxlist.get_field("keypoints").keypoints for boxlist in boxlists]
    kp = cat(kp, 0)

    fields = boxlists[0].get_fields()
    fields = [field for field in fields if field != "keypoints"]
    boxlists = [boxlist.copy_with_fields(fields) for boxlist in boxlists]

    boxlists = cat_boxlist(boxlists)
    boxlists.add_field("keypoints", kp)
    return boxlists
def select_over_all_levels(self, boxlists):
    num_images = len(boxlists)
    results = []
    for i in range(num_images):
        scores = boxlists[i].get_field("scores")
        labels = boxlists[i].get_field("labels")
        boxes = boxlists[i].bbox
        boxlist = boxlists[i]
        result = []
        # skip the background
        for j in range(1, self.num_classes):
            inds = (labels == j).nonzero().view(-1)
            scores_j = scores[inds]
            boxes_j = boxes[inds, :].view(-1, 4)
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class = boxlist_nms(
                boxlist_for_class, self.nms_thresh, score_field="scores"
            )
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels",
                torch.full((num_labels,), j, dtype=torch.int64, device=scores.device),
            )
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.fpn_post_nms_top_n > 0:
            cls_scores = result.get_field("scores")
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(),
                number_of_detections - self.fpn_post_nms_top_n + 1,
            )
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        results.append(result)
    return results
def filter_results(self, boxlist, num_classes):
    """Returns bounding-box detection results by thresholding on scores and
    applying non-maximum suppression (NMS).
    """
    # unwrap the boxlist to avoid additional overhead.
    # if we had multi-class NMS, we could perform this directly on the boxlist
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)

    device = scores.device
    result = []
    # Apply threshold on detection probabilities and apply NMS
    # Skip j = 0, because it's the background class
    inds_all = scores > self.score_thresh
    for j in range(1, num_classes):
        inds = inds_all[:, j].nonzero().squeeze(1)
        scores_j = scores[inds, j]
        boxes_j = boxes[inds, j * 4:(j + 1) * 4]
        boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
        boxlist_for_class.add_field("scores", scores_j)
        boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
        num_labels = len(boxlist_for_class)
        boxlist_for_class.add_field(
            "labels",
            torch.full((num_labels,), j, dtype=torch.int64, device=device),
        )
        result.append(boxlist_for_class)

    result = cat_boxlist(result)
    number_of_detections = len(result)

    # Limit to max_per_image detections **over all classes**
    if number_of_detections > self.detections_per_img > 0:
        cls_scores = result.get_field("scores")
        image_thresh, _ = torch.kthvalue(
            cls_scores.cpu(),
            number_of_detections - self.detections_per_img + 1,
        )
        keep = cls_scores >= image_thresh.item()
        keep = torch.nonzero(keep).squeeze(1)
        result = result[keep]
    return result
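# Both select_over_all_levels and filter_results cap the per-image detection
# count with the same torch.kthvalue trick: among n scores, the k-th smallest
# with k = n - top_n + 1 is exactly the score of the weakest detection still
# worth keeping, so no full sort is needed. A minimal standalone sketch (the
# function name and variables are illustrative, not from the repo); note that
# ties at the threshold can keep slightly more than top_n boxes, matching the
# `>=` comparison above:
import torch

def cap_detections(scores, top_n):
    """Return indices of (at least) the top_n highest scores."""
    n = scores.numel()
    if n <= top_n:
        return torch.arange(n)
    thresh, _ = torch.kthvalue(scores.cpu(), n - top_n + 1)
    return torch.nonzero(scores >= thresh.item()).squeeze(1)

print(cap_detections(torch.tensor([0.9, 0.1, 0.8, 0.3, 0.7]), 3))
# tensor([0, 2, 4])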
def __call__(self, anchors, objectness, box_regression, targets):
    """
    Arguments:
        anchors (list[list[BoxList]])
        objectness (list[Tensor])
        box_regression (list[Tensor])
        targets (list[BoxList])

    Returns:
        objectness_loss (Tensor)
        box_loss (Tensor)
    """
    anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
    labels, regression_targets = self.prepare_targets(anchors, targets)
    sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
    sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
    sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

    sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

    objectness, box_regression = \
        concat_box_prediction_layers(objectness, box_regression)

    objectness = objectness.squeeze()

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    # the box loss is computed on positive samples only, but normalized by
    # the total number of sampled anchors
    box_loss = smooth_l1_loss(
        box_regression[sampled_pos_inds],
        regression_targets[sampled_pos_inds],
        beta=1.0 / 9,
        size_average=False,
    ) / (sampled_inds.numel())

    objectness_loss = F.binary_cross_entropy_with_logits(
        objectness[sampled_inds], labels[sampled_inds]
    )

    return objectness_loss, box_loss
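# smooth_l1_loss above is a repo helper rather than
# torch.nn.functional.smooth_l1_loss (note the beta and size_average
# arguments). A minimal sketch consistent with that call signature, assuming
# the usual Huber-style definition:
import torch

def smooth_l1_loss(input, target, beta=1.0 / 9, size_average=True):
    # quadratic below beta, linear above; the two pieces join smoothly at
    # |input - target| == beta
    n = torch.abs(input - target)
    loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()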
def __call__(self, anchors, box_cls, box_regression, targets):
    """
    Arguments:
        anchors (list[BoxList])
        box_cls (list[Tensor])
        box_regression (list[Tensor])
        targets (list[BoxList])

    Returns:
        retinanet_cls_loss (Tensor)
        retinanet_regression_loss (Tensor)
    """
    anchors = [
        cat_boxlist(anchors_per_image) for anchors_per_image in anchors
    ]
    labels, regression_targets = self.prepare_targets(anchors, targets)

    N = len(labels)
    box_cls, box_regression = \
        concat_box_prediction_layers(box_cls, box_regression)

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    pos_inds = torch.nonzero(labels > 0).squeeze(1)

    retinanet_regression_loss = smooth_l1_loss(
        box_regression[pos_inds],
        regression_targets[pos_inds],
        beta=self.bbox_reg_beta,
        size_average=False,
    ) / (max(1, pos_inds.numel() * self.regress_norm))

    labels = labels.int()

    retinanet_cls_loss = self.box_cls_loss_func(
        box_cls, labels
    ) / (pos_inds.numel() + N)

    return retinanet_cls_loss, retinanet_regression_loss
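# box_cls_loss_func is not defined in this snippet; for a RetinaNet head it
# is presumably a sigmoid focal loss. The sketch below implements that under
# this assumption, and omits the ignored-anchor handling (labels == -1) a
# full implementation would also need:
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    # logits: (N, C) per-anchor class logits; labels: (N,) ints where 0 is
    # background and class j in [1, C] maps to logit column j - 1
    targets = torch.zeros_like(logits)
    pos = labels > 0
    targets[pos, labels[pos].long() - 1] = 1.0

    p = torch.sigmoid(logits)
    pt = p * targets + (1 - p) * (1 - targets)          # prob of the true label
    w = alpha * targets + (1 - alpha) * (1 - targets)   # class-balance weight
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return (w * (1.0 - pt) ** gamma * ce).sum()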
def add_gt_proposals(self, proposals, targets):
    """
    Arguments:
        proposals: list[BoxList]
        targets: list[BoxList]
    """
    # Get the device we're operating on
    device = proposals[0].bbox.device

    gt_boxes = [target.copy_with_fields([]) for target in targets]

    # later cat of bbox requires all fields to be present for all bbox
    # so we need to add a dummy for objectness that's missing
    for gt_box in gt_boxes:
        gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))

    proposals = [
        cat_boxlist((proposal, gt_box))
        for proposal, gt_box in zip(proposals, gt_boxes)
    ]

    return proposals
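# Why the dummy field is needed: cat_boxlist asserts that every input
# exposes the same field set before concatenating. A small usage example,
# assuming the maskrcnn-benchmark-style BoxList API used throughout this
# section (import paths may differ in this repo):
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist

size = (800, 600)  # (image width, image height)
props = BoxList(torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]]), size, mode="xyxy")
props.add_field("objectness", torch.tensor([0.9, 0.7]))
gt = BoxList(torch.tensor([[30., 30., 50., 60.]]), size, mode="xyxy")
gt.add_field("objectness", torch.ones(1))  # GT scored as certain foreground
merged = cat_boxlist([props, gt])          # len(merged) == 3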
def generate_feats(self, x, proposals, proposals_key=None, ver="local"):
    x = self.head(torch.cat(x, dim=0))
    if self.conv is not None:
        x = F.relu(self.conv(x))

    if proposals_key is not None:
        assert ver == "local"
        x_key = self.pooler((x[0:1, ...], ), proposals_key)
        x_key = x_key.flatten(start_dim=1)
    if proposals:
        x = self.pooler((x, ), proposals)
        x = x.flatten(start_dim=1)
        rois = cat_boxlist(proposals).bbox

    if ver == "local":
        x_key = F.relu(self.l_fcs[0](x_key))
    x = F.relu(self.l_fcs[0](x))

    if self.global_cache:
        if ver == "local":
            rois_key = proposals_key[0].bbox
            x_key = self.update_lm(x_key)
        x = self.update_lm(x)

    # distillation: keep only the first advanced_num proposals of every
    # base_num-sized per-frame chunk
    if ver in ("local", "memory"):
        x_dis = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(x, self.base_num, dim=0)
        ], dim=0)
        rois_dis = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(rois, self.base_num, dim=0)
        ], dim=0)

    if ver == "memory":
        self.memory_cache.append({
            "rois_cur": rois_dis,
            "rois_ref": rois,
            "feats_cur": x_dis,
            "feats_ref": x
        })
        for _ in range(self.stage - 1):
            self.memory_cache.append({
                "rois_cur": rois_dis,
                "rois_ref": rois_dis
            })
    elif ver == "local":
        self.local_cache.append({
            "rois_cur": torch.cat([rois_key, rois_dis], dim=0),
            "rois_ref": rois,
            "feats_cur": torch.cat([x_key, x_dis], dim=0),
            "feats_ref": x
        })
        for _ in range(self.stage - 2):
            self.local_cache.append({
                "rois_cur": torch.cat([rois_key, rois_dis], dim=0),
                "rois_ref": rois_dis
            })
        self.local_cache.append({
            "rois_cur": rois_key,
            "rois_ref": rois_dis
        })
    elif ver == "global":
        self.global_cache.append({"feats": x})
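# The distillation selection above keeps the first advanced_num rows of
# every base_num-sized chunk, i.e. (assuming proposals are stored sorted by
# objectness score) the top proposals of each reference frame. A toy
# illustration of the torch.split / slice / torch.cat idiom:
import torch

feats = torch.arange(12).view(12, 1)  # 3 frames x 4 proposals each
top = torch.cat([c[:2] for c in torch.split(feats, 4, dim=0)], dim=0)
print(top.flatten())  # tensor([0, 1, 4, 5, 8, 9])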
def _forward_test(self, x, proposals):
    proposals, proposals_ref, x_refs = proposals

    rois_cur = cat_boxlist(proposals).bbox
    rois_ref = proposals_ref.bbox

    if self.conv is not None:
        x = self.head(x)
        x = (F.relu(self.conv(x)), )
    else:
        x = (self.head(x), )

    x = self.pooler(x, proposals)
    x = x.flatten(start_dim=1)

    position_embedding = self.cal_position_embedding(rois_cur, rois_ref)

    for i in range(self.base_stage):
        x = F.relu(self.fcs[i](x))
        attention = self.attention_module_multi_head(
            x, x_refs, position_embedding,
            feat_dim=1024, group=16, dim=(1024, 1024, 1024), index=i)
        x = x + attention

    if self.advanced_stage > 0:
        x_refs_adv = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(x_refs, self.base_num, dim=0)
        ], dim=0)
        rois_ref_adv = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(rois_ref, self.base_num, dim=0)
        ], dim=0)
        position_embedding_adv = torch.cat([
            x[..., :self.advanced_num]
            for x in torch.split(position_embedding, self.base_num, dim=-1)
        ], dim=-1)
        position_embedding = self.cal_position_embedding(rois_ref_adv, rois_ref)

        for i in range(self.advanced_stage):
            attention = self.attention_module_multi_head(
                x_refs_adv, x_refs, position_embedding,
                feat_dim=1024, group=16, dim=(1024, 1024, 1024),
                index=i + self.base_stage)
            x_refs_adv = x_refs_adv + attention
            x_refs_adv = F.relu(self.fcs[i + self.base_stage](x_refs_adv))

        attention = self.attention_module_multi_head(
            x, x_refs_adv, position_embedding_adv,
            feat_dim=1024, group=16, dim=(1024, 1024, 1024),
            index=self.base_stage + self.advanced_stage)
        x = x + attention

    return x
def _forward_train(self, x, proposals):
    num_refs = len(x) - 1
    x = self.head(torch.cat(x, dim=0))
    if self.conv is not None:
        x = F.relu(self.conv(x))
    x, x_refs = torch.split(x, [1, num_refs], dim=0)

    proposals, proposals_cur, proposals_refs = \
        proposals[0][0], proposals[1], proposals[2:]

    x, x_cur = torch.split(
        self.pooler((x, ), [
            cat_boxlist([proposals, proposals_cur], ignore_field=True),
        ]),
        [len(proposals), len(proposals_cur)],
        dim=0)
    x, x_cur = x.flatten(start_dim=1), x_cur.flatten(start_dim=1)

    if proposals_refs:
        x_refs = self.pooler((x_refs, ), proposals_refs)
        x_refs = x_refs.flatten(start_dim=1)
        x_refs = torch.cat([x_cur, x_refs], dim=0)
    else:
        x_refs = x_cur

    rois_cur = proposals.bbox
    rois_ref = cat_boxlist([proposals_cur, *proposals_refs]).bbox
    position_embedding = self.cal_position_embedding(rois_cur, rois_ref)

    x_refs = F.relu(self.fcs[0](x_refs))

    for i in range(self.base_stage):
        x = F.relu(self.fcs[i](x))
        attention = self.attention_module_multi_head(
            x, x_refs, position_embedding,
            feat_dim=1024, group=16, dim=(1024, 1024, 1024), index=i)
        x = x + attention

    if self.advanced_stage > 0:
        x_refs_adv = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(x_refs, self.base_num, dim=0)
        ], dim=0)
        rois_ref_adv = torch.cat([
            x[:self.advanced_num]
            for x in torch.split(rois_ref, self.base_num, dim=0)
        ], dim=0)
        position_embedding_adv = torch.cat([
            x[..., :self.advanced_num]
            for x in torch.split(position_embedding, self.base_num, dim=-1)
        ], dim=-1)
        position_embedding = self.cal_position_embedding(rois_ref_adv, rois_ref)

        for i in range(self.advanced_stage):
            attention = self.attention_module_multi_head(
                x_refs_adv, x_refs, position_embedding,
                feat_dim=1024, group=16, dim=(1024, 1024, 1024),
                index=i + self.base_stage)
            x_refs_adv = x_refs_adv + attention
            x_refs_adv = F.relu(self.fcs[i + self.base_stage](x_refs_adv))

        attention = self.attention_module_multi_head(
            x, x_refs_adv, position_embedding_adv,
            feat_dim=1024, group=16, dim=(1024, 1024, 1024),
            index=self.base_stage + self.advanced_stage)
        x = x + attention

    return x
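# cal_position_embedding is not shown in this section. For relation-style
# attention it is presumably the geometric embedding from "Relation Networks
# for Object Detection": a 4-d log-space geometry matrix between every
# (current, reference) box pair, lifted to sinusoid features. A hedged,
# self-contained sketch under that assumption (names and the exact output
# layout are illustrative; the repo's version may differ):
import torch

def position_embedding(rois_cur, rois_ref, feat_dim=64, wave_len=1000.0):
    def center_wh(rois):  # rois: (N, 4) in xyxy
        w = rois[:, 2] - rois[:, 0] + 1
        h = rois[:, 3] - rois[:, 1] + 1
        return rois[:, 0] + 0.5 * w, rois[:, 1] + 0.5 * h, w, h

    cx1, cy1, w1, h1 = center_wh(rois_cur)
    cx2, cy2, w2, h2 = center_wh(rois_ref)
    # (N, M, 4) pairwise log-space geometry
    dx = torch.log(torch.clamp(torch.abs(cx1[:, None] - cx2[None]) / w1[:, None], min=1e-3))
    dy = torch.log(torch.clamp(torch.abs(cy1[:, None] - cy2[None]) / h1[:, None], min=1e-3))
    dw = torch.log(w2[None] / w1[:, None])
    dh = torch.log(h2[None] / h1[:, None])
    pos = torch.stack([dx, dy, dw, dh], dim=2)

    # sinusoidal lift: feat_dim / 8 frequencies, sin + cos per geometry channel
    freqs = wave_len ** (torch.arange(feat_dim // 8, dtype=torch.float32) / (feat_dim / 8))
    angles = 100.0 * pos[..., None] / freqs
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    return emb.flatten(start_dim=2)  # (N, M, feat_dim)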
def _forward_test(self, imgs, infos, targets=None):
    """
    forward for the test phase.
    :param imgs:
    :param infos:
    :param targets:
    :return:
    """
    def update_feature(img=None, feats=None, proposals=None, proposals_feat=None):
        assert (img is not None) or (feats is not None
                                     and proposals is not None
                                     and proposals_feat is not None)
        if img is not None:
            feats = self.backbone(img)[0]

            # note: `imgs` (not the new frame) is passed to the RPN here; only
            # its shape is used, so this is harmless, but it is not explicit.
            proposals = self.rpn(imgs, (feats, ), version="ref")
            proposals_feat = self.roi_heads.box.feature_extractor(
                feats, proposals, pre_calculate=True)

        self.feats.append(feats)
        self.proposals.append(proposals[0])
        self.proposals_dis.append(proposals[0][:self.advanced_num])
        self.proposals_feat.append(proposals_feat)
        self.proposals_feat_dis.append(proposals_feat[:self.advanced_num])

    if targets is not None:
        raise ValueError("In testing mode, targets should be None")

    if infos["frame_category"] == 0:  # a new video
        self.seg_len = infos["seg_len"]
        self.end_id = 0

        self.feats = deque(maxlen=self.all_frame_interval)
        self.proposals = deque(maxlen=self.all_frame_interval)
        self.proposals_dis = deque(maxlen=self.all_frame_interval)
        self.proposals_feat = deque(maxlen=self.all_frame_interval)
        self.proposals_feat_dis = deque(maxlen=self.all_frame_interval)

        # initialize the frame-memory queues for the new video
        self.roi_heads.box.feature_extractor.init_memory()
        if self.global_enable:
            # initialize the global feature pool
            self.roi_heads.box.feature_extractor.init_global()

        feats_cur = self.backbone(imgs.tensors)[0]
        proposals_cur = self.rpn(imgs, (feats_cur, ), version="ref")
        proposals_feat_cur = self.roi_heads.box.feature_extractor(
            feats_cur, proposals_cur, pre_calculate=True)

        # on the first frame, pad the window up to the key-frame location
        # with the current frame itself
        while len(self.feats) < self.key_frame_location + 1:
            update_feature(None, feats_cur, proposals_cur, proposals_feat_cur)

        # pre-fill future frames; for online (real-time) video this loop is skipped
        while len(self.feats) < self.all_frame_interval:
            # advance the frame counter by one, without running past the video length
            self.end_id = min(self.end_id + 1, self.seg_len - 1)
            end_filename = infos["pattern"] % self.end_id
            end_image = Image.open(infos["img_dir"] % end_filename).convert("RGB")
            end_image = infos["transforms"](end_image)
            if isinstance(end_image, tuple):
                end_image = end_image[0]
            end_image = end_image.view(1, *end_image.shape).to(self.device)
            update_feature(end_image)

    elif infos["frame_category"] == 1:
        # for non-initial frames, the local window advances by one frame
        self.end_id = min(self.end_id + 1, self.seg_len - 1)
        end_image = infos["ref_l"][0].tensors
        update_feature(end_image)

    # 1. update global
    if infos["ref_g"]:
        # update the global pool: the first frame contributes 10 frames,
        # later frames one each
        for global_img in infos["ref_g"]:
            feats = self.backbone(global_img.tensors)[0]
            proposals = self.rpn(global_img, (feats, ), version="ref")
            proposals_feat = self.roi_heads.box.feature_extractor(
                feats, proposals, pre_calculate=True)
            self.roi_heads.box.feature_extractor.update_global(proposals_feat)

    # the current (key) frame of the window
    feats = self.feats[self.key_frame_location]
    proposals, proposal_losses = self.rpn(imgs, (feats, ), None)

    proposals_ref = cat_boxlist(list(self.proposals))
    proposals_ref_dis = cat_boxlist(list(self.proposals_dis))
    proposals_feat_ref = torch.cat(list(self.proposals_feat), dim=0)
    proposals_feat_ref_dis = torch.cat(list(self.proposals_feat_dis), dim=0)

    proposals_list = [
        proposals, proposals_ref, proposals_ref_dis,
        proposals_feat_ref, proposals_feat_ref_dis
    ]

    if self.roi_heads:
        x, result, detector_losses = self.roi_heads(feats, proposals_list, None)
    else:
        result = proposals

    return result
def _forward_test(self, imgs, infos, targets=None):
    """
    forward for the test phase.
    :param imgs:
    :param infos: `infos["frame_category"]` is 0 for the start frame, 1 for normal frames
    :param targets:
    :return:
    """
    def update_feature(img=None, feats=None, proposals=None, proposals_feat=None):
        assert (img is not None) or (feats is not None
                                     and proposals is not None
                                     and proposals_feat is not None)
        if img is not None:
            feats = self.backbone(img)[0]

            # note: `imgs` (not the new frame) is passed to the RPN here; only
            # its shape is used, so this is harmless, but it is not explicit.
            proposals = self.rpn(imgs, (feats, ), version="ref")
            proposals_feat = self.roi_heads.box.feature_extractor(
                feats, proposals, pre_calculate=True)

        self.feats.append(feats)
        self.proposals.append(proposals[0])
        self.proposals_feat.append(proposals_feat)
        # if self.advanced:
        #     self.proposals_adv.append(proposals[0][:self.advanced_num])
        #     self.proposals_feat_dis.append(proposals_feat[:self.advanced_num])

    if targets is not None:
        raise ValueError("In testing mode, targets should be None")

    if infos["frame_category"] == 0:  # a new video
        self.seg_len = infos["seg_len"]
        self.end_id = 0

        self.feats = deque(maxlen=self.all_frame_interval)
        self.proposals = deque(maxlen=self.all_frame_interval)
        self.proposals_feat = deque(maxlen=self.all_frame_interval)
        # if self.advanced:
        #     self.proposals_adv = deque(maxlen=self.all_frame_interval)
        #     self.proposals_feat_adv = deque(maxlen=self.all_frame_interval)

        feats_cur = self.backbone(imgs.tensors)[0]
        proposals_cur = self.rpn(imgs, (feats_cur, ), version="ref")
        proposals_feat_cur = self.roi_heads.box.feature_extractor(
            feats_cur, proposals_cur, pre_calculate=True)
        while len(self.feats) < self.key_frame_location + 1:
            update_feature(None, feats_cur, proposals_cur, proposals_feat_cur)

        while len(self.feats) < self.all_frame_interval:
            self.end_id = min(self.end_id + 1, self.seg_len - 1)
            end_filename = infos["pattern"] % self.end_id
            end_image = Image.open(infos["img_dir"] % end_filename).convert("RGB")
            end_image = infos["transforms"](end_image)
            if isinstance(end_image, tuple):
                end_image = end_image[0]
            end_image = end_image.view(1, *end_image.shape).to(self.device)
            update_feature(end_image)

    elif infos["frame_category"] == 1:
        self.end_id = min(self.end_id + 1, self.seg_len - 1)
        end_image = infos["ref"][0].tensors
        update_feature(end_image)

    feats = self.feats[self.key_frame_location]
    proposals, proposal_losses = self.rpn(imgs, (feats, ), None)

    proposals_ref = cat_boxlist(list(self.proposals))
    proposals_feat_ref = torch.cat(list(self.proposals_feat), dim=0)
    # if self.advanced:
    #     proposals_ref_adv = cat_boxlist(list(self.proposals_adv))
    #     proposals_feat_ref_adv = torch.cat(list(self.proposals_feat_adv), dim=0)
    #
    #     proposals_list = [proposals, proposals_ref, proposals_feat_ref,
    #                       proposals_ref_adv, proposals_feat_ref_adv]
    # else:
    proposals_list = [proposals, proposals_ref, proposals_feat_ref]

    if self.roi_heads:
        x, result, detector_losses = self.roi_heads(feats, proposals_list, None)
    else:
        result = proposals

    return result
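# Both _forward_test variants rely on collections.deque(maxlen=K) for the
# frame buffers: once the window is full, appending a new frame silently
# evicts the oldest one, which is exactly the sliding local window the
# update logic above needs. A minimal illustration:
from collections import deque

window = deque(maxlen=3)
for frame_id in range(5):
    window.append(frame_id)
    print(list(window))
# [0] -> [0, 1] -> [0, 1, 2] -> [1, 2, 3] -> [2, 3, 4]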